Compare commits
18 Commits
b0ebb7006e
...
qhy4
| Author | SHA1 | Date | |
|---|---|---|---|
| 43ab0f71b0 | |||
| 5e0e21d91b | |||
| d5bec53f61 | |||
| 508b91f5a2 | |||
| 3101252c25 | |||
| f386a5810b | |||
| 352a79035f | |||
| 9a08e27a19 | |||
| b558856e1e | |||
| dcbcb2c377 | |||
| ff43432ef9 | |||
| afa12ba031 | |||
| bf4d66c874 | |||
| 9347a4ebe5 | |||
| 223a50f9e0 | |||
| 2a6068f9e4 | |||
| 91a9b0febc | |||
| ed637c972b |
21
.claude/settings.local.json
Normal file
21
.claude/settings.local.json
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
{
|
||||||
|
"permissions": {
|
||||||
|
"allow": [
|
||||||
|
"Bash(conda env list:*)",
|
||||||
|
"Bash(mamba env:*)",
|
||||||
|
"Bash(micromamba env list:*)",
|
||||||
|
"Bash(echo:*)",
|
||||||
|
"Bash(git show:*)",
|
||||||
|
"Bash(nvidia-smi:*)",
|
||||||
|
"Bash(conda activate unifolm-wma)",
|
||||||
|
"Bash(conda info:*)",
|
||||||
|
"Bash(direnv allow:*)",
|
||||||
|
"Bash(ls:*)",
|
||||||
|
"Bash(for scenario in unitree_g1_pack_camera unitree_z1_dual_arm_cleanup_pencils unitree_z1_dual_arm_stackbox unitree_z1_dual_arm_stackbox_v2 unitree_z1_stackbox)",
|
||||||
|
"Bash(do for case in case1 case2 case3 case4)",
|
||||||
|
"Bash(done)",
|
||||||
|
"Bash(chmod:*)",
|
||||||
|
"Bash(ln:*)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
2
.envrc
Normal file
2
.envrc
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
eval "$(conda shell.bash hook 2>/dev/null)"
|
||||||
|
conda activate unifolm-wma
|
||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -55,7 +55,6 @@ coverage.xml
|
|||||||
*.pot
|
*.pot
|
||||||
|
|
||||||
# Django stuff:
|
# Django stuff:
|
||||||
|
|
||||||
local_settings.py
|
local_settings.py
|
||||||
db.sqlite3
|
db.sqlite3
|
||||||
|
|
||||||
@@ -121,6 +120,7 @@ localTest/
|
|||||||
fig/
|
fig/
|
||||||
figure/
|
figure/
|
||||||
*.mp4
|
*.mp4
|
||||||
|
|
||||||
Data/ControlVAE.yml
|
Data/ControlVAE.yml
|
||||||
Data/Misc
|
Data/Misc
|
||||||
Data/Pretrained
|
Data/Pretrained
|
||||||
@@ -129,4 +129,6 @@ Experiment/checkpoint
|
|||||||
Experiment/log
|
Experiment/log
|
||||||
|
|
||||||
*.ckpt
|
*.ckpt
|
||||||
|
|
||||||
*.0
|
*.0
|
||||||
|
ckpts/unifolm_wma_dual.ckpt.prepared.pt
|
||||||
|
|||||||
135
case4_run.log
135
case4_run.log
@@ -1,135 +0,0 @@
|
|||||||
nohup: ignoring input
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
|
||||||
__import__("pkg_resources").declare_namespace(__name__)
|
|
||||||
2026-02-08 07:38:45.572744: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
|
||||||
2026-02-08 07:38:45.576864: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
|
||||||
2026-02-08 07:38:45.624825: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
|
||||||
2026-02-08 07:38:45.624883: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
|
||||||
2026-02-08 07:38:45.627150: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
|
||||||
2026-02-08 07:38:45.638316: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
|
||||||
2026-02-08 07:38:45.638803: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
|
||||||
2026-02-08 07:38:46.426363: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
|
||||||
[rank: 0] Global seed set to 123
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
|
||||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
state_dict = torch.load(ckpt, map_location="cpu")
|
|
||||||
>>> model checkpoint loaded.
|
|
||||||
>>> Load pre-trained model ...
|
|
||||||
INFO:root:***** Configing Data *****
|
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
|
||||||
>>> unitree_z1_stackbox: normalizer initiated.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
|
||||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
|
||||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
|
||||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
|
||||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
|
||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
|
||||||
>>> Dataset is successfully loaded ...
|
|
||||||
>>> Generate 16 frames under each generation ...
|
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
|
||||||
|
|
||||||
0%| | 0/7 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
|
||||||
proj = linear(q, w, b)
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
|
||||||
attn_output = scaled_dot_product_attention(
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
|
||||||
attn_output = scaled_dot_product_attention(
|
|
||||||
>>> Step 0: generating actions ...
|
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
|
|
||||||
14%|█▍ | 1/7 [01:38<09:52, 98.73s/it]
|
|
||||||
29%|██▊ | 2/7 [03:17<08:14, 98.85s/it]
|
|
||||||
43%|████▎ | 3/7 [04:56<06:35, 98.80s/it]
|
|
||||||
57%|█████▋ | 4/7 [06:35<04:56, 98.94s/it]
|
|
||||||
71%|███████▏ | 5/7 [08:14<03:17, 98.93s/it]
|
|
||||||
86%|████████▌ | 6/7 [09:53<01:38, 98.89s/it]
|
|
||||||
100%|██████████| 7/7 [11:31<00:00, 98.81s/it]
|
|
||||||
100%|██████████| 7/7 [11:31<00:00, 98.85s/it]
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 1: generating actions ...
|
|
||||||
>>> Step 1: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 2: generating actions ...
|
|
||||||
>>> Step 2: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 3: generating actions ...
|
|
||||||
>>> Step 3: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 4: generating actions ...
|
|
||||||
>>> Step 4: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 5: generating actions ...
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
{"framework": "pytorch", "task": "robotics", "allow_remote": true}
|
|
||||||
@@ -222,7 +222,7 @@ data:
|
|||||||
test:
|
test:
|
||||||
target: unifolm_wma.data.wma_data.WMAData
|
target: unifolm_wma.data.wma_data.WMAData
|
||||||
params:
|
params:
|
||||||
data_dir: '/mnt/ASC1637/unifolm-world-model-action/examples/world_model_interaction_prompts'
|
data_dir: '/home/qhy/unifolm-world-model-action/examples/world_model_interaction_prompts'
|
||||||
video_length: ${model.params.wma_config.params.temporal_length}
|
video_length: ${model.params.wma_config.params.temporal_length}
|
||||||
frame_stride: 2
|
frame_stride: 2
|
||||||
load_raw_resolution: True
|
load_raw_resolution: True
|
||||||
|
|||||||
122
docs/architecture_overview.typ
Normal file
122
docs/architecture_overview.typ
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
== Task Comprehension: Diffusion Model and UnifoLM-WMA
|
||||||
|
|
||||||
|
This section provides a comprehensive overview of the UnifoLM-WMA-0 deep learning architecture, serving as a practical foundation for the optimization strategies discussed in subsequent sections.
|
||||||
|
|
||||||
|
=== Overall Inference Pipeline
|
||||||
|
|
||||||
|
UnifoLM-WMA-0 is Unitree Robotics' open-source World-Model-Action framework. Its core task is to predict future video frame sequences along with the corresponding robot action and state trajectories, given a current observation image and a text instruction. The model operates in an interactive simulation mode: each iteration consumes the previous prediction as input and generates the next segment of video and actions, thereby forming a closed-loop rollout. A single iteration of this pipeline can be decomposed into four sequential stages — condition encoding, VAE encoding, DDIM diffusion sampling, and VAE decoding — each of which is described below.
|
||||||
|
|
||||||
|
==== Condition Encoding
|
||||||
|
|
||||||
|
The condition encoding stage transforms raw multi-modal inputs into a unified context vector that guides the diffusion denoising process, through three parallel encoding paths. On the image side, the input observation image (320#sym.times 512) is processed by a frozen OpenCLIP ViT-H-14 vision encoder, then compressed through a Resampler — a Perceiver-based cross-attention module (4 layers, 12 heads, dim\_head=64, embed\_dim 1280 #sym.arrow 1024) — into 16 image condition tokens per frame, yielding $16 times T = 256$ image tokens for T=16 frames.
|
||||||
|
|
||||||
|
On the text side, the instruction is encoded by a frozen OpenCLIP text encoder (`FrozenOpenCLIPEmbedder`, penultimate layer output) into 77 tokens of dimension 1024, computed once and reused across all DDIM steps. On the state side, the robot proprioceptive state (dim 16) is mapped through a SATokenProjector (Perceiver Attention, 1 layer, 16 heads, dim\_head=64, 16 learnable queries) into 16 tokens of dimension 1024.
|
||||||
|
|
||||||
|
These three token sets are concatenated to form the unified context vector: `[agent_state(2) | agent_action(16) | text(77) | image(256)]`, totaling 351 tokens per cross-attention operation.
|
||||||
|
|
||||||
|
==== VAE Encoding
|
||||||
|
|
||||||
|
The observation images are encoded into a compact latent space through an AutoencoderKL (`autoencoder.py`) — a variational autoencoder regularized by KL divergence. The encoder follows a convolutional architecture with 4-level channel multipliers [1, 2, 4, 4] (base channels ch=128, yielding channel widths [128, 256, 512, 512]), 2 residual blocks per level, and a latent channel count of z\_channels=4. The input RGB frames at resolution 320#sym.times 512 are encoded into latent representations at 1/8 spatial resolution, producing tensors of shape `(B, 4, T, 40, 64)`.
|
||||||
|
|
||||||
|
A critical configuration parameter is `perframe_ae=True`, which means the VAE processes each of the T=16 frames independently rather than as a 3D volume. While this per-frame strategy avoids the memory overhead of volumetric convolutions, it introduces a sequential loop of T forward passes through the encoder — a point worth noting for latency optimization. The latent representations are scaled by a fixed factor of `scale_factor=0.18215` before being fed into the diffusion process.
|
||||||
|
|
||||||
|
==== DDIM Diffusion Sampling
|
||||||
|
|
||||||
|
This is the core time-consuming part of inference. A DDIM (Denoising Diffusion Implicit Models) sampler (`ddim.py`) is employed with a default of 50 denoising steps. The diffusion process is parameterized with v-prediction (`parameterization="v"`), 1000 training timesteps, and a linear beta schedule from `linear_start=0.00085` to `linear_end=0.012`, with zero-SNR terminal rescaling enabled (`rescale_betas_zero_snr=True`) and dynamic rescaling applied at `base_scale=0.7` to stabilize generation quality.
|
||||||
|
|
||||||
|
Unlike standard video diffusion models that only predict denoised video latents, UnifoLM-WMA simultaneously produces three outputs per step: a video latent prediction `y` of shape `(B, 4, T, 40, 64)`, an action trajectory prediction `a_y` of shape `(B, T, 16)`, and a state trajectory prediction `s_y` of shape `(B, T, 16)`. The three predictions share the same diffusion timestep but employ heterogeneous noise schedules — the video stream uses the DDPM schedule with v-prediction, while the action and state streams use a `DDIMScheduler` from the `diffusers` library with epsilon-prediction and a `squaredcos_cap_v2` beta schedule. This design allows each modality to adopt its optimal denoising strategy.
|
||||||
|
|
||||||
|
The sampler also supports classifier-free guidance with `unconditional_guidance_scale` and guidance rescaling, applied only to the video stream to balance generation quality and diversity.
|
||||||
|
|
||||||
|
==== VAE Decoding
|
||||||
|
|
||||||
|
After the DDIM sampling loop completes, the denoised video latent tensor $x_0$ of shape `(B, 4, T, 40, 64)` is decoded back to RGB pixel space through the AutoencoderKL decoder. Due to the `perframe_ae=True` configuration, decoding is likewise performed frame-by-frame: each of the T=16 latent frames is individually inverse-scaled by $1 slash "scale_factor"$, passed through the decoder's convolutional transpose layers, and reconstructed to a 320#sym.times 512 RGB frame.
|
||||||
|
|
||||||
|
In the interactive simulation mode, the decoded video serves a dual purpose — providing the observation image for the next iteration's condition encoding (only the first `exe_steps` frames are needed) and producing the final output video for visualization and evaluation. The action and state trajectories predicted by the DDIM loop are directly used for robot control without further decoding.
|
||||||
|
|
||||||
|
=== WMAModel Backbone: Dual-UNet Collaborative Architecture
|
||||||
|
|
||||||
|
The WMAModel (`wma_model.py:326`) is the core neural network invoked at every DDIM step, employing a unique dual-UNet collaborative architecture that jointly predicts video, actions, and states within a single forward pass. This tightly-coupled design enables the action and state predictions to directly leverage the rich spatiotemporal features extracted by the video generation backbone, rather than treating them as independent prediction heads.
|
||||||
|
|
||||||
|
==== Video UNet
|
||||||
|
|
||||||
|
The primary backbone is a 2D convolution-based UNet with temporal extensions. Its key configuration is summarized in the following table:
|
||||||
|
|
||||||
|
#figure(
|
||||||
|
table(
|
||||||
|
columns: (3fr, 5fr),
|
||||||
|
[*Parameter*], [*Value*],
|
||||||
|
[Input / Output channels], [8 (4 latent + 4 conditioning) / 4],
|
||||||
|
[Base model channels], [320],
|
||||||
|
[Channel multipliers], [\[1, 2, 4, 4\] #sym.arrow widths \[320, 640, 1280, 1280\]],
|
||||||
|
[Residual blocks per level], [2],
|
||||||
|
[Attention resolutions], [\[4, 2, 1\] (3 of 4 resolution levels)],
|
||||||
|
[Attention head channels], [64],
|
||||||
|
[Transformer depth], [1 per attention resolution],
|
||||||
|
[Context dimension], [1024],
|
||||||
|
[Temporal length], [16 frames],
|
||||||
|
),
|
||||||
|
caption: [Video UNet configuration parameters.],
|
||||||
|
)
|
||||||
|
|
||||||
|
The UNet follows the classic encoder-middle-decoder structure with skip connections. At each attention-enabled resolution level, every ResBlock is followed by two transformer modules: a SpatialTransformer that performs spatial self-attention among all $H times W$ tokens within each frame followed by cross-attention with the 351-token context vector, and a TemporalTransformer that performs self-attention among T=16 time-step tokens at each spatial position (configured with `temporal_selfatt_only=True`, i.e., no cross-attention).
|
||||||
|
|
||||||
|
During the forward pass, intermediate feature maps are collected after each Downsample layer and the middle block, reshaped from $(B times T, C, H, W)$ to $(B, T, C, H, W)$, accumulating 10 multi-scale feature maps in `hs_a` — the bridge to the Action/State UNets.
|
||||||
|
|
||||||
|
==== Action UNet and State UNet
|
||||||
|
|
||||||
|
The Action UNet (`conditional_unet1d.py`) is a 1D convolutional UNet specifically designed for predicting robot action trajectories. Its configuration is as follows:
|
||||||
|
|
||||||
|
#figure(
|
||||||
|
table(
|
||||||
|
columns: (3fr, 5fr),
|
||||||
|
[*Parameter*], [*Value*],
|
||||||
|
[Input dimension], [16 (agent\_action\_dim)],
|
||||||
|
[Down channel widths], [\[256, 512, 1024, 2048\]],
|
||||||
|
[Kernel size], [5],
|
||||||
|
[GroupNorm groups], [8],
|
||||||
|
[Diffusion step embedding dim], [128],
|
||||||
|
[Horizon], [16],
|
||||||
|
[Action projection dim], [32],
|
||||||
|
),
|
||||||
|
caption: [Action UNet (ConditionalUnet1D) configuration parameters.],
|
||||||
|
)
|
||||||
|
|
||||||
|
The Action UNet receives the 10 `hs_a` feature maps from the Video UNet as visual conditioning. The conditioning pipeline involves three stages: (1) SpatialSoftmax compresses each 2D feature map into keypoint coordinates $(B times T, C, 2)$; (2) the compressed features are concatenated with the diffusion timestep embedding and observation encoding (ResNet-18 `MultiImageObsEncoder`), then injected via FiLM modulation to produce per-channel scale/bias for the 1D convolution blocks; (3) `ActionLatentImageCrossAttention` enables action tokens to cross-attend to the Video UNet's spatiotemporal features, allowing visually-grounded action planning.
|
||||||
|
|
||||||
|
The input action tensor $(B, T, 16)$ is projected to act\_proj\_dim=32, processed through the 1D UNet, then projected back to $(B, T, 16)$.
|
||||||
|
|
||||||
|
The State UNet is an identical `ConditionalUnet1D` instance with the same hyperparameters, operating on the state tensor `x_state` $(B, T, 16)$ instead of the action tensor.
|
||||||
|
|
||||||
|
A critical optimization observation: the Action and State UNets are computationally independent — sharing read-only inputs with no data dependencies. The original code executes them sequentially, leaving significant room for CUDA stream parallelization.
|
||||||
|
|
||||||
|
=== Multi-Level Design of Attention Mechanisms
|
||||||
|
|
||||||
|
The attention mechanisms in UnifoLM-WMA constitute the core computational bottleneck of inference. Their design encompasses four distinct levels, each serving a different purpose in the model's spatiotemporal reasoning, and understanding their structure is essential for identifying optimization opportunities.
|
||||||
|
|
||||||
|
The first level is *spatial self-attention* within the SpatialTransformer. For a latent frame at resolution $H times W$, the token count is $H times W$ (e.g., $40 times 64 = 2560$ at the highest resolution). Implemented via xformers `memory_efficient_attention`, reducing peak memory from $O(N^2)$ to $O(N)$. Q/K/V use bias-free linear layers with head count = channel\_dim / num\_head\_channels (e.g., 1280/64 = 20 heads).
|
||||||
|
|
||||||
|
The second level is *multi-source cross-attention*, the most distinctive design in UnifoLM-WMA. The unified context vector is split into four semantic sources, each with dedicated K/V projection layers:
|
||||||
|
|
||||||
|
#figure(
|
||||||
|
table(
|
||||||
|
columns: (2fr, 1fr, 3fr, 2fr),
|
||||||
|
[*Source*], [*Tokens*], [*K/V Projections*], [*Scale*],
|
||||||
|
[Text], [77], [`to_k` / `to_v` (shared base)], [1.0],
|
||||||
|
[Image], [16#sym.times T], [`to_k_ip` / `to_v_ip`], [`image_cross_attention_scale`],
|
||||||
|
[Agent state], [2], [`to_k_as` / `to_v_as`], [`agent_state_cross_attention_scale`],
|
||||||
|
[Agent action], [16], [`to_k_aa` / `to_v_aa`], [`agent_action_cross_attention_scale`],
|
||||||
|
),
|
||||||
|
caption: [Multi-source cross-attention configuration.],
|
||||||
|
)
|
||||||
|
|
||||||
|
The Query vector Q is always derived from the video latent features via `to_q`. For each of the four sources, independent attention scores are computed — $"softmax"(Q dot K_i^T \/ sqrt(d)) dot V_i$ — producing four separate attention outputs. These outputs are then combined via weighted summation:
|
||||||
|
|
||||||
|
$ "out" = "out"_"text" + alpha_"img" dot "out"_"ip" + alpha_"state" dot "out"_"as" + alpha_"action" dot "out"_"aa" $
|
||||||
|
|
||||||
|
In the current configuration, `cross_attention_scale_learnable=False` (fixed scales). This decoupled design adds 8 extra linear layers versus standard single-source cross-attention, creating opportunities for KV fusion optimization.
|
||||||
|
|
||||||
|
The third level is *temporal self-attention* within the TemporalTransformer. The input $(B, C, T, H, W)$ is reshaped to $(B times H times W, C, T)$, so each spatial position becomes an independent batch element and T=16 time steps form the token sequence. Supports relative position encoding via a `RelativePosition` module and optional causal masks; current configuration uses bidirectional temporal attention.
|
||||||
|
|
||||||
|
The fourth level is *action-latent-image cross-attention* in the `ActionLatentImageCrossAttention` module. Action tokens $(B, "action_dim", "act_proj_dim")$ as Query cross-attend to Video UNet features reshaped to $(B, T times H times W, C)$ as Key/Value. A `BasicTransformerBlock` (depth=1) performs action self-attention then cross-attention to video features, with zero-initialized `proj_out` and residual connection. This mechanism is the key bridge enabling the action head to access the visual world model's internal representations.
|
||||||
|
|
||||||
21
env.sh
21
env.sh
@@ -1,21 +0,0 @@
|
|||||||
# Note: This script should be sourced, not executed
|
|
||||||
# Usage: source env.sh
|
|
||||||
#
|
|
||||||
# If you need render group permissions, run this first:
|
|
||||||
# newgrp render
|
|
||||||
# Then source this script:
|
|
||||||
# source env.sh
|
|
||||||
|
|
||||||
# Initialize conda
|
|
||||||
source /mnt/ASC1637/miniconda3/etc/profile.d/conda.sh
|
|
||||||
|
|
||||||
# Activate conda environment
|
|
||||||
conda activate unifolm-wma-o
|
|
||||||
|
|
||||||
# Set HuggingFace cache directories
|
|
||||||
export HF_HOME=/mnt/ASC1637/hf_home
|
|
||||||
export HUGGINGFACE_HUB_CACHE=/mnt/ASC1637/hf_home/hub
|
|
||||||
|
|
||||||
echo "Environment configured successfully"
|
|
||||||
echo "Conda environment: unifolm-wma-o"
|
|
||||||
echo "HF_HOME: $HF_HOME"
|
|
||||||
@@ -1,217 +0,0 @@
|
|||||||
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_unet.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml
|
|
||||||
|
|
||||||
|
|
||||||
==================================================================================================================================
|
|
||||||
FLOPS BY ATen OPERATOR (FlopCounterMode)
|
|
||||||
==================================================================================================================================
|
|
||||||
ATen Op | GFLOPS | % of Total
|
|
||||||
-------------------------------------------------------
|
|
||||||
convolution | 6185.17 | 46.4%
|
|
||||||
addmm | 4411.17 | 33.1%
|
|
||||||
mm | 1798.34 | 13.5%
|
|
||||||
bmm | 949.54 | 7.1%
|
|
||||||
|
|
||||||
==================================================================================================================================
|
|
||||||
FLOPS BY MODULE (FlopCounterMode)
|
|
||||||
==================================================================================================================================
|
|
||||||
Module | GFLOPS | % of Total
|
|
||||||
------------------------------------------------------------------------------------------
|
|
||||||
Global | 13344.23 | 100.0%
|
|
||||||
DiffusionWrapper | 13344.23 | 100.0%
|
|
||||||
DiffusionWrapper.diffusion_model | 13344.23 | 100.0%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.8 | 997.87 | 7.5%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.5 | 992.91 | 7.4%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.9 | 941.81 | 7.1%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.10 | 857.93 | 6.4%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.11 | 857.93 | 6.4%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.6 | 821.71 | 6.2%
|
|
||||||
DiffusionWrapper.diffusion_model.input_blocks.1 | 765.65 | 5.7%
|
|
||||||
DiffusionWrapper.diffusion_model.input_blocks.2 | 765.65 | 5.7%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.7 | 737.82 | 5.5%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.3 | 732.87 | 5.5%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.4 | 732.87 | 5.5%
|
|
||||||
DiffusionWrapper.diffusion_model.input_blocks.5 | 645.55 | 4.8%
|
|
||||||
DiffusionWrapper.diffusion_model.input_blocks.8 | 640.59 | 4.8%
|
|
||||||
DiffusionWrapper.diffusion_model.input_blocks.4 | 611.99 | 4.6%
|
|
||||||
DiffusionWrapper.diffusion_model.input_blocks.7 | 607.04 | 4.5%
|
|
||||||
DiffusionWrapper.diffusion_model.init_attn.0 | 459.02 | 3.4%
|
|
||||||
DiffusionWrapper.diffusion_model.init_attn | 459.02 | 3.4%
|
|
||||||
nWrapper.diffusion_model.init_attn.0.transformer_blocks.0 | 432.18 | 3.2%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.6.0 | 427.85 | 3.2%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.9.0 | 427.83 | 3.2%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.3.0 | 343.99 | 2.6%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.4.0 | 343.99 | 2.6%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.7.0 | 343.96 | 2.6%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.10.0 | 343.95 | 2.6%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.11.0 | 343.95 | 2.6%
|
|
||||||
DiffusionWrapper.diffusion_model.input_blocks.1.1 | 327.75 | 2.5%
|
|
||||||
DiffusionWrapper.diffusion_model.input_blocks.2.1 | 327.75 | 2.5%
|
|
||||||
|
|
||||||
==================================================================================================================================
|
|
||||||
SUMMARY
|
|
||||||
==================================================================================================================================
|
|
||||||
Total CUDA time: 761.4 ms
|
|
||||||
Matmul CUDA time: 404.2 ms (53.1%)
|
|
||||||
Non-matmul CUDA time: 357.1 ms (46.9%)
|
|
||||||
Total FLOPS (FlopCounter): 13344.23 GFLOPS
|
|
||||||
Matmul throughput: 33.01 TFLOPS/s (54.1% of BF16 peak)
|
|
||||||
Overall throughput: 17.53 TFLOPS/s (28.7% of BF16 peak)
|
|
||||||
GPU peak (BF16): 61.0 TFLOPS
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
==================================================================================================================================
|
|
||||||
FLOPS BY ATen OPERATOR (FlopCounterMode)
|
|
||||||
==================================================================================================================================
|
|
||||||
ATen Op | GFLOPS | % of Total
|
|
||||||
-------------------------------------------------------
|
|
||||||
convolution | 6185.17 | 46.4%
|
|
||||||
addmm | 4411.17 | 33.1%
|
|
||||||
mm | 1798.34 | 13.5%
|
|
||||||
bmm | 949.54 | 7.1%
|
|
||||||
|
|
||||||
==================================================================================================================================
|
|
||||||
FLOPS BY MODULE (FlopCounterMode)
|
|
||||||
==================================================================================================================================
|
|
||||||
Module | GFLOPS | % of Total
|
|
||||||
------------------------------------------------------------------------------------------
|
|
||||||
DiffusionWrapper | 13344.23 | 100.0%
|
|
||||||
Global | 13344.23 | 100.0%
|
|
||||||
DiffusionWrapper.diffusion_model | 13344.23 | 100.0%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.8 | 997.87 | 7.5%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.5 | 992.91 | 7.4%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.9 | 941.81 | 7.1%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.10 | 857.93 | 6.4%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.11 | 857.93 | 6.4%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.6 | 821.71 | 6.2%
|
|
||||||
DiffusionWrapper.diffusion_model.input_blocks.1 | 765.65 | 5.7%
|
|
||||||
DiffusionWrapper.diffusion_model.input_blocks.2 | 765.65 | 5.7%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.7 | 737.82 | 5.5%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.3 | 732.87 | 5.5%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.4 | 732.87 | 5.5%
|
|
||||||
DiffusionWrapper.diffusion_model.input_blocks.5 | 645.55 | 4.8%
|
|
||||||
DiffusionWrapper.diffusion_model.input_blocks.8 | 640.59 | 4.8%
|
|
||||||
DiffusionWrapper.diffusion_model.input_blocks.4 | 611.99 | 4.6%
|
|
||||||
DiffusionWrapper.diffusion_model.input_blocks.7 | 607.04 | 4.5%
|
|
||||||
DiffusionWrapper.diffusion_model.init_attn | 459.02 | 3.4%
|
|
||||||
DiffusionWrapper.diffusion_model.init_attn.0 | 459.02 | 3.4%
|
|
||||||
nWrapper.diffusion_model.init_attn.0.transformer_blocks.0 | 432.18 | 3.2%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.6.0 | 427.85 | 3.2%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.9.0 | 427.83 | 3.2%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.3.0 | 343.99 | 2.6%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.4.0 | 343.99 | 2.6%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.7.0 | 343.96 | 2.6%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.10.0 | 343.95 | 2.6%
|
|
||||||
DiffusionWrapper.diffusion_model.output_blocks.11.0 | 343.95 | 2.6%
|
|
||||||
DiffusionWrapper.diffusion_model.input_blocks.1.1 | 327.75 | 2.5%
|
|
||||||
DiffusionWrapper.diffusion_model.input_blocks.2.1 | 327.75 | 2.5%
|
|
||||||
|
|
||||||
==================================================================================================================================
|
|
||||||
SUMMARY
|
|
||||||
==================================================================================================================================
|
|
||||||
Total CUDA time: 707.1 ms
|
|
||||||
Matmul CUDA time: 403.1 ms (57.0%)
|
|
||||||
Non-matmul CUDA time: 304.0 ms (43.0%)
|
|
||||||
Total FLOPS (FlopCounter): 13344.23 GFLOPS
|
|
||||||
Matmul throughput: 33.11 TFLOPS/s (54.3% of BF16 peak)
|
|
||||||
Overall throughput: 18.87 TFLOPS/s (30.9% of BF16 peak)
|
|
||||||
GPU peak (BF16): 61.0 TFLOPS
|
|
||||||
(unifolm-wma) ASC1637@wx-ms-w7900d-0033:/mnt/ASC1637/unifolm-world-model-action$
|
|
||||||
|
|
||||||
========================================================================
|
|
||||||
TABLE 1: STAGE TIMING
|
|
||||||
========================================================================
|
|
||||||
Stage Mean(ms) Std %
|
|
||||||
------------------------------------------------------------------------
|
|
||||||
1_Image_Embedding 29.5 0.16 0.1%
|
|
||||||
2_VAE_Encode 51.3 0.06 0.1%
|
|
||||||
3_Text_Conditioning 14.7 0.18 0.0%
|
|
||||||
4_Projectors 0.2 0.03 0.0%
|
|
||||||
5_DDIM_Loop 33392.5 3.21 97.3%
|
|
||||||
6_VAE_Decode 808.4 1.00 2.4%
|
|
||||||
7_Post_Process 15.8 0.56 0.0%
|
|
||||||
------------------------------------------------------------------------
|
|
||||||
TOTAL 34312.4
|
|
||||||
|
|
||||||
================================================================================
|
|
||||||
TABLE 2: UNET SUB-MODULE BREAKDOWN
|
|
||||||
================================================================================
|
|
||||||
Module Type Total(ms) Count Per-call %
|
|
||||||
--------------------------------------------------------------------------------
|
|
||||||
ResBlock 10256.3 1100 9.32 23.2%
|
|
||||||
SpatialTransformer 9228.2 800 11.54 20.9%
|
|
||||||
CrossAttention 8105.8 3300 2.46 18.3%
|
|
||||||
ConditionalUnet1D 6409.5 100 64.10 14.5%
|
|
||||||
TemporalTransformer 5847.0 850 6.88 13.2%
|
|
||||||
FeedForward 4338.1 1650 2.63 9.8%
|
|
||||||
UNet.out 73.8 50 1.48 0.2%
|
|
||||||
--------------------------------------------------------------------------------
|
|
||||||
TOTAL (hooked) 44258.7
|
|
||||||
|
|
||||||
==========================================================================================
|
|
||||||
TABLE 2b: PER-UNET-BLOCK TIMING (coarse modules, per DDIM loop)
|
|
||||||
==========================================================================================
|
|
||||||
Block Total(ms) % Breakdown
|
|
||||||
------------------------------------------------------------------------------------------
|
|
||||||
input_blocks.1 3376.2 7.6% SpatialTransformer=1101, CrossAttention=990, ResBlock=543, TemporalTransformer=454, FeedForward=288
|
|
||||||
input_blocks.2 3374.0 7.6% SpatialTransformer=1100, CrossAttention=991, ResBlock=540, TemporalTransformer=455, FeedForward=288
|
|
||||||
input_blocks.4 1592.4 3.6% SpatialTransformer=394, ResBlock=374, CrossAttention=303, TemporalTransformer=272, FeedForward=249
|
|
||||||
input_blocks.5 1642.5 3.7% ResBlock=425, SpatialTransformer=397, CrossAttention=303, TemporalTransformer=271, FeedForward=247
|
|
||||||
input_blocks.7 1469.0 3.3% ResBlock=416, SpatialTransformer=324, FeedForward=251, CrossAttention=240, TemporalTransformer=237
|
|
||||||
input_blocks.8 1543.7 3.5% ResBlock=491, SpatialTransformer=325, FeedForward=250, CrossAttention=240, TemporalTransformer=238
|
|
||||||
input_blocks.10 217.5 0.5% ResBlock=218
|
|
||||||
input_blocks.11 216.8 0.5% ResBlock=217
|
|
||||||
middle_block 848.9 1.9% ResBlock=434, SpatialTransformer=151, CrossAttention=134, TemporalTransformer=69, FeedForward=61
|
|
||||||
output_blocks.0 303.2 0.7% ResBlock=303
|
|
||||||
output_blocks.1 303.1 0.7% ResBlock=303
|
|
||||||
output_blocks.2 302.8 0.7% ResBlock=303
|
|
||||||
output_blocks.3 1734.8 3.9% ResBlock=687, SpatialTransformer=322, FeedForward=249, CrossAttention=239, TemporalTransformer=237
|
|
||||||
output_blocks.4 1739.8 3.9% ResBlock=688, SpatialTransformer=323, FeedForward=251, CrossAttention=239, TemporalTransformer=238
|
|
||||||
output_blocks.5 1622.3 3.7% ResBlock=570, SpatialTransformer=324, FeedForward=251, CrossAttention=239, TemporalTransformer=238
|
|
||||||
output_blocks.6 1881.0 4.3% ResBlock=664, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=250
|
|
||||||
output_blocks.7 1768.0 4.0% ResBlock=554, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=249
|
|
||||||
output_blocks.8 1688.7 3.8% ResBlock=474, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=249
|
|
||||||
output_blocks.9 3558.6 8.0% SpatialTransformer=1096, CrossAttention=992, ResBlock=727, TemporalTransformer=454, FeedForward=290
|
|
||||||
output_blocks.10 3492.8 7.9% SpatialTransformer=1096, CrossAttention=992, ResBlock=662, TemporalTransformer=454, FeedForward=289
|
|
||||||
output_blocks.11 3493.3 7.9% SpatialTransformer=1096, CrossAttention=992, ResBlock=662, TemporalTransformer=454, FeedForward=289
|
|
||||||
out 73.8 0.2% UNet.out=74
|
|
||||||
action_unet 3212.0 7.3% ConditionalUnet1D=3212
|
|
||||||
state_unet 3197.6 7.2% ConditionalUnet1D=3198
|
|
||||||
other 1606.2 3.6% TemporalTransformer=960, FeedForward=337, CrossAttention=309
|
|
||||||
------------------------------------------------------------------------------------------
|
|
||||||
TOTAL 44258.7
|
|
||||||
|
|
||||||
======================================================================
|
|
||||||
TABLE 2c: ATTENTION vs FEEDFORWARD (deep hooks)
|
|
||||||
======================================================================
|
|
||||||
Component Total(ms) %
|
|
||||||
----------------------------------------------------------------------
|
|
||||||
CrossAttention 8105.8 65.1%
|
|
||||||
FeedForward 4338.1 34.9%
|
|
||||||
----------------------------------------------------------------------
|
|
||||||
TOTAL (attn+ff) 12443.9
|
|
||||||
|
|
||||||
==================================================
|
|
||||||
TABLE 3: MEMORY SUMMARY
|
|
||||||
==================================================
|
|
||||||
Initial allocated: 11.82 GB
|
|
||||||
Peak allocated: 14.43 GB
|
|
||||||
Delta (pipeline): 2.61 GB
|
|
||||||
|
|
||||||
============================================================
|
|
||||||
TABLE 4: THROUGHPUT
|
|
||||||
============================================================
|
|
||||||
Total pipeline latency: 34312.4 ms
|
|
||||||
DDIM loop latency: 33392.5 ms
|
|
||||||
DDIM steps: 50
|
|
||||||
CFG scale: 1.0 (1x UNet/step)
|
|
||||||
UNet forward calls: 50
|
|
||||||
Per DDIM step: 667.9 ms
|
|
||||||
Per UNet forward: 667.9 ms
|
|
||||||
VAE encode bandwidth: 0.1 GB/s (peak HBM: 864.0 GB/s)
|
|
||||||
VAE decode bandwidth: 0.0 GB/s (peak HBM: 864.0 GB/s)
|
|
||||||
GPU BF16 peak: 61.0 TFLOPS
|
|
||||||
|
|
||||||
Done.
|
|
||||||
150
run.log
150
run.log
@@ -1,150 +0,0 @@
|
|||||||
nohup: ignoring input
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
|
||||||
__import__("pkg_resources").declare_namespace(__name__)
|
|
||||||
2026-02-08 08:15:49.934949: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
|
||||||
2026-02-08 08:15:49.937974: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
|
||||||
2026-02-08 08:15:49.969069: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
|
||||||
2026-02-08 08:15:49.969100: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
|
||||||
2026-02-08 08:15:49.970909: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
|
||||||
2026-02-08 08:15:49.979005: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
|
||||||
2026-02-08 08:15:49.979255: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
|
||||||
2026-02-08 08:15:50.597743: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
|
||||||
[rank: 0] Global seed set to 123
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
|
||||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
state_dict = torch.load(ckpt, map_location="cpu")
|
|
||||||
>>> model checkpoint loaded.
|
|
||||||
>>> Load pre-trained model ...
|
|
||||||
INFO:root:***** Configing Data *****
|
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
|
||||||
>>> unitree_z1_stackbox: normalizer initiated.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
|
||||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
|
||||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
|
||||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
|
||||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
|
||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
|
||||||
>>> Dataset is successfully loaded ...
|
|
||||||
>>> Generate 16 frames under each generation ...
|
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
|
||||||
|
|
||||||
0%| | 0/12 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
|
||||||
proj = linear(q, w, b)
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
|
||||||
attn_output = scaled_dot_product_attention(
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
|
||||||
attn_output = scaled_dot_product_attention(
|
|
||||||
>>> Step 0: generating actions ...
|
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
|
|
||||||
8%|▊ | 1/12 [01:37<17:51, 97.37s/it]
|
|
||||||
17%|█▋ | 2/12 [03:14<16:13, 97.31s/it]
|
|
||||||
25%|██▌ | 3/12 [04:51<14:35, 97.26s/it]
|
|
||||||
33%|███▎ | 4/12 [06:29<12:58, 97.25s/it]
|
|
||||||
42%|████▏ | 5/12 [08:06<11:20, 97.24s/it]
|
|
||||||
50%|█████ | 6/12 [09:43<09:43, 97.24s/it]
|
|
||||||
58%|█████▊ | 7/12 [11:20<08:06, 97.27s/it]
|
|
||||||
67%|██████▋ | 8/12 [12:58<06:29, 97.36s/it]
|
|
||||||
75%|███████▌ | 9/12 [14:36<04:52, 97.49s/it]
|
|
||||||
83%|████████▎ | 10/12 [16:13<03:15, 97.52s/it]
|
|
||||||
92%|█████████▏| 11/12 [17:51<01:37, 97.47s/it]
|
|
||||||
100%|██████████| 12/12 [19:28<00:00, 97.35s/it]
|
|
||||||
100%|██████████| 12/12 [19:28<00:00, 97.35s/it]
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 1: generating actions ...
|
|
||||||
>>> Step 1: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 2: generating actions ...
|
|
||||||
>>> Step 2: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 3: generating actions ...
|
|
||||||
>>> Step 3: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 4: generating actions ...
|
|
||||||
>>> Step 4: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 5: generating actions ...
|
|
||||||
>>> Step 5: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 6: generating actions ...
|
|
||||||
>>> Step 6: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 7: generating actions ...
|
|
||||||
>>> Step 7: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 8: generating actions ...
|
|
||||||
>>> Step 8: interacting with world model ...
|
|
||||||
114
run_all_case.sh
Normal file
114
run_all_case.sh
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# 自动执行所有场景的所有case
|
||||||
|
# 总共5个场景,每个场景4个case,共20个case
|
||||||
|
# 设置环境变量(离线模式)
|
||||||
|
export HF_HUB_OFFLINE=1
|
||||||
|
export TRANSFORMERS_OFFLINE=1
|
||||||
|
|
||||||
|
# 颜色定义
|
||||||
|
RED='\033[0;31m'
|
||||||
|
GREEN='\033[0;32m'
|
||||||
|
YELLOW='\033[1;33m'
|
||||||
|
BLUE='\033[0;34m'
|
||||||
|
NC='\033[0m' # No Color
|
||||||
|
|
||||||
|
# 定义所有场景
|
||||||
|
SCENARIOS=(
|
||||||
|
"unitree_g1_pack_camera"
|
||||||
|
"unitree_z1_dual_arm_cleanup_pencils"
|
||||||
|
"unitree_z1_dual_arm_stackbox"
|
||||||
|
"unitree_z1_dual_arm_stackbox_v2"
|
||||||
|
"unitree_z1_stackbox"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 定义case数量
|
||||||
|
CASES=(1 2 3 4)
|
||||||
|
|
||||||
|
# 记录开始时间
|
||||||
|
START_TIME=$(date +%s)
|
||||||
|
LOG_FILE="run_all_cases_$(date +%Y%m%d_%H%M%S).log"
|
||||||
|
|
||||||
|
echo -e "${BLUE}========================================${NC}"
|
||||||
|
echo -e "${BLUE}开始执行所有场景的case${NC}"
|
||||||
|
echo -e "${BLUE}总共: ${#SCENARIOS[@]} 个场景 x ${#CASES[@]} 个case = $((${#SCENARIOS[@]} * ${#CASES[@]})) 个任务${NC}"
|
||||||
|
echo -e "${BLUE}日志文件: ${LOG_FILE}${NC}"
|
||||||
|
echo -e "${BLUE}========================================${NC}"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# 初始化计数器
|
||||||
|
TOTAL_CASES=$((${#SCENARIOS[@]} * ${#CASES[@]}))
|
||||||
|
CURRENT_CASE=0
|
||||||
|
SUCCESS_COUNT=0
|
||||||
|
FAIL_COUNT=0
|
||||||
|
|
||||||
|
# 记录失败的case
|
||||||
|
declare -a FAILED_CASES
|
||||||
|
|
||||||
|
# 遍历所有场景
|
||||||
|
for scenario in "${SCENARIOS[@]}"; do
|
||||||
|
echo -e "${YELLOW}>>> 场景: ${scenario}${NC}"
|
||||||
|
|
||||||
|
# 遍历所有case
|
||||||
|
for case_num in "${CASES[@]}"; do
|
||||||
|
CURRENT_CASE=$((CURRENT_CASE + 1))
|
||||||
|
case_dir="${scenario}/case${case_num}"
|
||||||
|
script_path="${case_dir}/run_world_model_interaction.sh"
|
||||||
|
|
||||||
|
echo -e "${BLUE}[${CURRENT_CASE}/${TOTAL_CASES}] 执行: ${case_dir}${NC}"
|
||||||
|
|
||||||
|
# 检查脚本是否存在
|
||||||
|
if [ ! -f "${script_path}" ]; then
|
||||||
|
echo -e "${RED}错误: 脚本不存在 ${script_path}${NC}"
|
||||||
|
FAIL_COUNT=$((FAIL_COUNT + 1))
|
||||||
|
FAILED_CASES+=("${case_dir} (脚本不存在)")
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 执行脚本
|
||||||
|
echo "开始时间: $(date '+%Y-%m-%d %H:%M:%S')"
|
||||||
|
|
||||||
|
if bash "${script_path}" >> "${LOG_FILE}" 2>&1; then
|
||||||
|
echo -e "${GREEN}✓ 成功: ${case_dir}${NC}"
|
||||||
|
SUCCESS_COUNT=$((SUCCESS_COUNT + 1))
|
||||||
|
else
|
||||||
|
echo -e "${RED}✗ 失败: ${case_dir}${NC}"
|
||||||
|
FAIL_COUNT=$((FAIL_COUNT + 1))
|
||||||
|
FAILED_CASES+=("${case_dir}")
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "结束时间: $(date '+%Y-%m-%d %H:%M:%S')"
|
||||||
|
echo ""
|
||||||
|
done
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
done
|
||||||
|
|
||||||
|
# 计算总耗时
|
||||||
|
END_TIME=$(date +%s)
|
||||||
|
DURATION=$((END_TIME - START_TIME))
|
||||||
|
HOURS=$((DURATION / 3600))
|
||||||
|
MINUTES=$(((DURATION % 3600) / 60))
|
||||||
|
SECONDS=$((DURATION % 60))
|
||||||
|
|
||||||
|
# 输出总结
|
||||||
|
echo -e "${BLUE}========================================${NC}"
|
||||||
|
echo -e "${BLUE}执行完成!${NC}"
|
||||||
|
echo -e "${BLUE}========================================${NC}"
|
||||||
|
echo -e "总任务数: ${TOTAL_CASES}"
|
||||||
|
echo -e "${GREEN}成功: ${SUCCESS_COUNT}${NC}"
|
||||||
|
echo -e "${RED}失败: ${FAIL_COUNT}${NC}"
|
||||||
|
echo -e "总耗时: ${HOURS}小时 ${MINUTES}分钟 ${SECONDS}秒"
|
||||||
|
echo -e "详细日志: ${LOG_FILE}"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# 如果有失败的case,列出来
|
||||||
|
if [ ${FAIL_COUNT} -gt 0 ]; then
|
||||||
|
echo -e "${RED}失败的case列表:${NC}"
|
||||||
|
for failed_case in "${FAILED_CASES[@]}"; do
|
||||||
|
echo -e "${RED} - ${failed_case}${NC}"
|
||||||
|
done
|
||||||
|
echo ""
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo -e "${BLUE}========================================${NC}"
|
||||||
2328
run_all_cases_20260211_135725.log
Normal file
2328
run_all_cases_20260211_135725.log
Normal file
File diff suppressed because it is too large
Load Diff
37
run_all_cases_20260211_173422.log
Normal file
37
run_all_cases_20260211_173422.log
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
2026-02-11 17:34:29.188470: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
|
2026-02-11 17:34:29.238296: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
|
2026-02-11 17:34:29.238342: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
|
2026-02-11 17:34:29.239649: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
|
2026-02-11 17:34:29.247152: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
|
2026-02-11 17:34:30.172640: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
|
Global seed set to 123
|
||||||
|
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||||
|
>>> Prepared model loaded.
|
||||||
|
>>> Diffusion backbone (model.model) converted to FP16.
|
||||||
|
>>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.
|
||||||
|
>>> Encoders (cond_stage_model, embedder) converted to FP16.
|
||||||
|
INFO:root:***** Configing Data *****
|
||||||
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||||
|
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||||
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
|
>>> Dataset is successfully loaded ...
|
||||||
|
✓ KV fused: 66 attention layers
|
||||||
|
>>> Generate 16 frames under each generation ...
|
||||||
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
2388
run_all_cases_20260211_173635.log
Normal file
2388
run_all_cases_20260211_173635.log
Normal file
File diff suppressed because it is too large
Load Diff
0
run_all_cases_20260211_181733.log
Normal file
0
run_all_cases_20260211_181733.log
Normal file
1408
run_all_cases_20260219_185527.log
Normal file
1408
run_all_cases_20260219_185527.log
Normal file
File diff suppressed because it is too large
Load Diff
61
run_all_psnr.sh
Executable file
61
run_all_psnr.sh
Executable file
@@ -0,0 +1,61 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||||
|
cd "$SCRIPT_DIR"
|
||||||
|
|
||||||
|
SCENARIOS=(
|
||||||
|
unitree_g1_pack_camera
|
||||||
|
unitree_z1_dual_arm_cleanup_pencils
|
||||||
|
unitree_z1_dual_arm_stackbox
|
||||||
|
unitree_z1_dual_arm_stackbox_v2
|
||||||
|
unitree_z1_stackbox
|
||||||
|
)
|
||||||
|
|
||||||
|
CASES=(case1 case2 case3 case4)
|
||||||
|
|
||||||
|
total=0
|
||||||
|
success=0
|
||||||
|
fail=0
|
||||||
|
|
||||||
|
for scenario in "${SCENARIOS[@]}"; do
|
||||||
|
for case in "${CASES[@]}"; do
|
||||||
|
case_dir="${scenario}/${case}"
|
||||||
|
gt_video="${case_dir}/${scenario}_${case}.mp4"
|
||||||
|
pred_video=$(ls "${case_dir}"/output/inference/*_full_fs*.mp4 2>/dev/null | head -1)
|
||||||
|
output_file="${case_dir}/psnr_result.json"
|
||||||
|
|
||||||
|
total=$((total + 1))
|
||||||
|
echo "=========================================="
|
||||||
|
echo "[${total}/20] ${case_dir}"
|
||||||
|
|
||||||
|
if [ ! -f "$gt_video" ]; then
|
||||||
|
echo " SKIP: GT video not found: $gt_video"
|
||||||
|
fail=$((fail + 1))
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
if [ -z "$pred_video" ]; then
|
||||||
|
echo " SKIP: pred video not found in ${case_dir}/output/inference/"
|
||||||
|
fail=$((fail + 1))
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo " GT: $gt_video"
|
||||||
|
echo " Pred: $pred_video"
|
||||||
|
echo " Out: $output_file"
|
||||||
|
|
||||||
|
if python3 psnr_score_for_challenge.py \
|
||||||
|
--gt_video "$gt_video" \
|
||||||
|
--pred_video "$pred_video" \
|
||||||
|
--output_file "$output_file"; then
|
||||||
|
success=$((success + 1))
|
||||||
|
echo " DONE"
|
||||||
|
else
|
||||||
|
fail=$((fail + 1))
|
||||||
|
echo " FAILED"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
done
|
||||||
|
|
||||||
|
echo "=========================================="
|
||||||
|
echo "Finished: ${success} success, ${fail} fail, ${total} total"
|
||||||
@@ -16,6 +16,9 @@ from collections import OrderedDict
|
|||||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||||
from unifolm_wma.utils.utils import instantiate_from_config
|
from unifolm_wma.utils.utils import instantiate_from_config
|
||||||
|
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
|
|
||||||
def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
|
def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,975 +0,0 @@
|
|||||||
"""
|
|
||||||
Profile the full iteration loop of world model interaction.
|
|
||||||
|
|
||||||
Three layers of profiling:
|
|
||||||
Layer 1: Iteration-level wall-clock breakdown (CUDA events)
|
|
||||||
Layer 2: GPU timeline trace (torch.profiler → Chrome trace)
|
|
||||||
Layer 3: A/B comparison (standardized CSV output)
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
# Layer 1 only (fast, default):
|
|
||||||
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 \
|
|
||||||
python scripts/evaluation/profile_iteration.py \
|
|
||||||
--ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
|
|
||||||
--config configs/inference/world_model_interaction.yaml \
|
|
||||||
--prompt_dir unitree_z1_dual_arm_cleanup_pencils/case1/world_model_interaction_prompts \
|
|
||||||
--dataset unitree_z1_dual_arm_cleanup_pencils \
|
|
||||||
--frame_stride 4 --n_iter 5
|
|
||||||
|
|
||||||
# Layer 1 + Layer 2 (GPU trace):
|
|
||||||
... --trace --trace_dir ./profile_traces
|
|
||||||
|
|
||||||
# Layer 3 (A/B comparison): run twice, diff the CSVs
|
|
||||||
... --csv baseline.csv
|
|
||||||
... --csv optimized.csv
|
|
||||||
python scripts/evaluation/profile_iteration.py --compare baseline.csv optimized.csv
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import csv
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from collections import defaultdict, deque
|
|
||||||
from contextlib import nullcontext
|
|
||||||
|
|
||||||
import h5py
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
import torch
|
|
||||||
import torchvision
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
from PIL import Image
|
|
||||||
from pytorch_lightning import seed_everything
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
|
||||||
from unifolm_wma.utils.utils import instantiate_from_config
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────
|
|
||||||
# Constants
|
|
||||||
# ──────────────────────────────────────────────────────────────────────
|
|
||||||
STAGE_NAMES = [
|
|
||||||
"stack_to_device_1",
|
|
||||||
"synth_policy",
|
|
||||||
"update_action_queue",
|
|
||||||
"stack_to_device_2",
|
|
||||||
"synth_world_model",
|
|
||||||
"update_obs_queue",
|
|
||||||
"tensorboard_log",
|
|
||||||
"save_results",
|
|
||||||
"cpu_transfer",
|
|
||||||
"itr_total",
|
|
||||||
]
|
|
||||||
|
|
||||||
# Sub-stages inside image_guided_synthesis_sim_mode
|
|
||||||
SYNTH_SUB_STAGES = [
|
|
||||||
"ddim_sampler_init",
|
|
||||||
"image_embedding",
|
|
||||||
"vae_encode",
|
|
||||||
"text_conditioning",
|
|
||||||
"projectors",
|
|
||||||
"cond_assembly",
|
|
||||||
"ddim_sampling",
|
|
||||||
"vae_decode",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────
|
|
||||||
# CudaTimer — GPU-precise timing via CUDA events
|
|
||||||
# ──────────────────────────────────────────────────────────────────────
|
|
||||||
class CudaTimer:
|
|
||||||
"""Context manager that records GPU time between enter/exit using CUDA events."""
|
|
||||||
|
|
||||||
def __init__(self, name, records):
|
|
||||||
self.name = name
|
|
||||||
self.records = records
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
self._start = torch.cuda.Event(enable_timing=True)
|
|
||||||
self._end = torch.cuda.Event(enable_timing=True)
|
|
||||||
self._start.record()
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, *args):
|
|
||||||
self._end.record()
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
elapsed_ms = self._start.elapsed_time(self._end)
|
|
||||||
self.records[self.name].append(elapsed_ms)
|
|
||||||
|
|
||||||
|
|
||||||
class WallTimer:
|
|
||||||
"""Context manager that records CPU wall-clock time (for pure-CPU stages)."""
|
|
||||||
|
|
||||||
def __init__(self, name, records):
|
|
||||||
self.name = name
|
|
||||||
self.records = records
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
self._t0 = time.perf_counter()
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, *args):
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
elapsed_ms = (time.perf_counter() - self._t0) * 1000.0
|
|
||||||
self.records[self.name].append(elapsed_ms)
|
|
||||||
|
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────
|
|
||||||
# Model loading (reused from world_model_interaction.py)
|
|
||||||
# ──────────────────────────────────────────────────────────────────────
|
|
||||||
def patch_norm_bypass_autocast():
|
|
||||||
def _group_norm_forward(self, x):
|
|
||||||
with torch.amp.autocast('cuda', enabled=False):
|
|
||||||
return F.group_norm(
|
|
||||||
x, self.num_groups,
|
|
||||||
self.weight.to(x.dtype) if self.weight is not None else None,
|
|
||||||
self.bias.to(x.dtype) if self.bias is not None else None,
|
|
||||||
self.eps)
|
|
||||||
|
|
||||||
def _layer_norm_forward(self, x):
|
|
||||||
with torch.amp.autocast('cuda', enabled=False):
|
|
||||||
return F.layer_norm(
|
|
||||||
x, self.normalized_shape,
|
|
||||||
self.weight.to(x.dtype) if self.weight is not None else None,
|
|
||||||
self.bias.to(x.dtype) if self.bias is not None else None,
|
|
||||||
self.eps)
|
|
||||||
|
|
||||||
torch.nn.GroupNorm.forward = _group_norm_forward
|
|
||||||
torch.nn.LayerNorm.forward = _layer_norm_forward
|
|
||||||
|
|
||||||
|
|
||||||
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
|
|
||||||
from unifolm_wma.modules.networks.wma_model import ResBlock
|
|
||||||
unet = model.model.diffusion_model
|
|
||||||
compiled = 0
|
|
||||||
for idx in hot_indices:
|
|
||||||
block = unet.output_blocks[idx]
|
|
||||||
for layer in block:
|
|
||||||
if isinstance(layer, ResBlock):
|
|
||||||
layer._forward = torch.compile(layer._forward, mode="default")
|
|
||||||
compiled += 1
|
|
||||||
print(f" torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(args):
|
|
||||||
config = OmegaConf.load(args.config)
|
|
||||||
config['model']['params']['wma_config']['params']['use_checkpoint'] = False
|
|
||||||
model = instantiate_from_config(config.model)
|
|
||||||
model.perframe_ae = args.perframe_ae
|
|
||||||
|
|
||||||
from collections import OrderedDict
|
|
||||||
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
|
||||||
if "state_dict" in state_dict:
|
|
||||||
state_dict = state_dict["state_dict"]
|
|
||||||
try:
|
|
||||||
model.load_state_dict(state_dict, strict=True)
|
|
||||||
except Exception:
|
|
||||||
new_sd = OrderedDict()
|
|
||||||
for k, v in state_dict.items():
|
|
||||||
new_sd[k] = v
|
|
||||||
for k in list(new_sd.keys()):
|
|
||||||
if "framestride_embed" in k:
|
|
||||||
new_sd[k.replace("framestride_embed", "fps_embedding")] = new_sd.pop(k)
|
|
||||||
model.load_state_dict(new_sd, strict=True)
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
# Apply precision: bf16 diffusion + encoders + projectors, fp32/bf16 VAE
|
|
||||||
model.model.to(torch.bfloat16)
|
|
||||||
model.diffusion_autocast_dtype = torch.bfloat16
|
|
||||||
model.embedder.to(torch.bfloat16)
|
|
||||||
model.image_proj_model.to(torch.bfloat16)
|
|
||||||
model.encoder_autocast_dtype = None
|
|
||||||
model.state_projector.to(torch.bfloat16)
|
|
||||||
model.action_projector.to(torch.bfloat16)
|
|
||||||
model.projector_autocast_dtype = None
|
|
||||||
if args.vae_dtype == "bf16":
|
|
||||||
model.first_stage_model.to(torch.bfloat16)
|
|
||||||
|
|
||||||
# Compile hot ResBlocks
|
|
||||||
apply_torch_compile(model)
|
|
||||||
model = model.cuda()
|
|
||||||
print(">>> Model loaded and ready.")
|
|
||||||
return model, config
|
|
||||||
|
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────
|
|
||||||
# Data preparation (reused from world_model_interaction.py)
|
|
||||||
# ──────────────────────────────────────────────────────────────────────
|
|
||||||
def get_init_frame_path(data_dir, sample):
|
|
||||||
rel = os.path.join(sample['data_dir'], str(sample['videoid']) + '.png')
|
|
||||||
return os.path.join(data_dir, 'images', rel)
|
|
||||||
|
|
||||||
|
|
||||||
def get_transition_path(data_dir, sample):
|
|
||||||
rel = os.path.join(sample['data_dir'], str(sample['videoid']) + '.h5')
|
|
||||||
return os.path.join(data_dir, 'transitions', rel)
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_init_input(start_idx, init_frame_path, transition_dict,
|
|
||||||
frame_stride, wma_data, video_length=16, n_obs_steps=2):
|
|
||||||
indices = [start_idx + frame_stride * i for i in range(video_length)]
|
|
||||||
init_frame = Image.open(init_frame_path).convert('RGB')
|
|
||||||
init_frame = torch.tensor(np.array(init_frame)).unsqueeze(0).permute(3, 0, 1, 2).float()
|
|
||||||
|
|
||||||
if start_idx < n_obs_steps - 1:
|
|
||||||
state_indices = list(range(0, start_idx + 1))
|
|
||||||
states = transition_dict['observation.state'][state_indices, :]
|
|
||||||
num_padding = n_obs_steps - 1 - start_idx
|
|
||||||
padding = states[0:1, :].repeat(num_padding, 1)
|
|
||||||
states = torch.cat((padding, states), dim=0)
|
|
||||||
else:
|
|
||||||
state_indices = list(range(start_idx - n_obs_steps + 1, start_idx + 1))
|
|
||||||
states = transition_dict['observation.state'][state_indices, :]
|
|
||||||
|
|
||||||
actions = transition_dict['action'][indices, :]
|
|
||||||
ori_state_dim = states.shape[-1]
|
|
||||||
ori_action_dim = actions.shape[-1]
|
|
||||||
|
|
||||||
frames_action_state_dict = {
|
|
||||||
'action': actions,
|
|
||||||
'observation.state': states,
|
|
||||||
}
|
|
||||||
frames_action_state_dict = wma_data.normalizer(frames_action_state_dict)
|
|
||||||
frames_action_state_dict = wma_data.get_uni_vec(
|
|
||||||
frames_action_state_dict,
|
|
||||||
transition_dict['action_type'],
|
|
||||||
transition_dict['state_type'],
|
|
||||||
)
|
|
||||||
|
|
||||||
if wma_data.spatial_transform is not None:
|
|
||||||
init_frame = wma_data.spatial_transform(init_frame)
|
|
||||||
init_frame = (init_frame / 255 - 0.5) * 2
|
|
||||||
|
|
||||||
data = {'observation.image': init_frame}
|
|
||||||
data.update(frames_action_state_dict)
|
|
||||||
return data, ori_state_dim, ori_action_dim
|
|
||||||
|
|
||||||
|
|
||||||
def populate_queues(queues, batch):
|
|
||||||
for key in batch:
|
|
||||||
if key not in queues:
|
|
||||||
continue
|
|
||||||
if len(queues[key]) != queues[key].maxlen:
|
|
||||||
while len(queues[key]) != queues[key].maxlen:
|
|
||||||
queues[key].append(batch[key])
|
|
||||||
else:
|
|
||||||
queues[key].append(batch[key])
|
|
||||||
return queues
|
|
||||||
|
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────
|
|
||||||
# Instrumented image_guided_synthesis_sim_mode with sub-stage timing
|
|
||||||
# ──────────────────────────────────────────────────────────────────────
|
|
||||||
def get_latent_z(model, videos):
|
|
||||||
b, c, t, h, w = videos.shape
|
|
||||||
x = rearrange(videos, 'b c t h w -> (b t) c h w')
|
|
||||||
vae_dtype = next(model.first_stage_model.parameters()).dtype
|
|
||||||
x = x.to(dtype=vae_dtype)
|
|
||||||
z = model.encode_first_stage(x)
|
|
||||||
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
|
|
||||||
return z
|
|
||||||
|
|
||||||
|
|
||||||
def save_results(video, filename, fps=8):
|
|
||||||
video = video.detach().cpu()
|
|
||||||
video = torch.clamp(video.float(), -1., 1.)
|
|
||||||
n = video.shape[0]
|
|
||||||
video = video.permute(2, 0, 1, 3, 4)
|
|
||||||
frame_grids = [
|
|
||||||
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
|
|
||||||
for framesheet in video
|
|
||||||
]
|
|
||||||
grid = torch.stack(frame_grids, dim=0)
|
|
||||||
grid = (grid + 1.0) / 2.0
|
|
||||||
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
|
|
||||||
torchvision.io.write_video(filename, grid, fps=fps,
|
|
||||||
video_codec='h264', options={'crf': '10'})
|
|
||||||
|
|
||||||
|
|
||||||
def profiled_synthesis(model, prompts, observation, noise_shape,
|
|
||||||
ddim_steps, ddim_eta, unconditional_guidance_scale,
|
|
||||||
fs, text_input, timestep_spacing, guidance_rescale,
|
|
||||||
sim_mode, decode_video, records, prefix):
|
|
||||||
"""image_guided_synthesis_sim_mode with per-sub-stage CUDA event timing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prefix: "policy" or "wm" — prepended to sub-stage names in records.
|
|
||||||
"""
|
|
||||||
b, _, t, _, _ = noise_shape
|
|
||||||
batch_size = noise_shape[0]
|
|
||||||
device = next(model.parameters()).device
|
|
||||||
|
|
||||||
# --- sub-stage: ddim_sampler_init ---
|
|
||||||
with CudaTimer(f"{prefix}/ddim_sampler_init", records):
|
|
||||||
ddim_sampler = DDIMSampler(model)
|
|
||||||
fs_t = torch.tensor([fs] * batch_size, dtype=torch.long, device=device)
|
|
||||||
|
|
||||||
# --- sub-stage: image_embedding ---
|
|
||||||
with CudaTimer(f"{prefix}/image_embedding", records):
|
|
||||||
model_dtype = next(model.embedder.parameters()).dtype
|
|
||||||
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
|
||||||
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=model_dtype)
|
|
||||||
cond_img_emb = model.embedder(cond_img)
|
|
||||||
cond_img_emb = model.image_proj_model(cond_img_emb)
|
|
||||||
|
|
||||||
# --- sub-stage: vae_encode ---
|
|
||||||
with CudaTimer(f"{prefix}/vae_encode", records):
|
|
||||||
if model.model.conditioning_key == 'hybrid':
|
|
||||||
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
|
|
||||||
img_cat_cond = z[:, :, -1:, :, :]
|
|
||||||
img_cat_cond = repeat(img_cat_cond,
|
|
||||||
'b c t h w -> b c (repeat t) h w',
|
|
||||||
repeat=noise_shape[2])
|
|
||||||
cond = {"c_concat": [img_cat_cond]}
|
|
||||||
|
|
||||||
# --- sub-stage: text_conditioning ---
|
|
||||||
with CudaTimer(f"{prefix}/text_conditioning", records):
|
|
||||||
if not text_input:
|
|
||||||
prompts_use = [""] * batch_size
|
|
||||||
else:
|
|
||||||
prompts_use = prompts if isinstance(prompts, list) else [prompts] * batch_size
|
|
||||||
cond_ins_emb = model.get_learned_conditioning(prompts_use)
|
|
||||||
|
|
||||||
# --- sub-stage: projectors ---
|
|
||||||
with CudaTimer(f"{prefix}/projectors", records):
|
|
||||||
projector_dtype = next(model.state_projector.parameters()).dtype
|
|
||||||
cond_state_emb = model.state_projector(
|
|
||||||
observation['observation.state'].to(dtype=projector_dtype))
|
|
||||||
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
|
|
||||||
cond_action_emb = model.action_projector(
|
|
||||||
observation['action'].to(dtype=projector_dtype))
|
|
||||||
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
|
|
||||||
if not sim_mode:
|
|
||||||
cond_action_emb = torch.zeros_like(cond_action_emb)
|
|
||||||
|
|
||||||
# --- sub-stage: cond_assembly ---
|
|
||||||
with CudaTimer(f"{prefix}/cond_assembly", records):
|
|
||||||
cond["c_crossattn"] = [
|
|
||||||
torch.cat([cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb], dim=1)
|
|
||||||
]
|
|
||||||
cond["c_crossattn_action"] = [
|
|
||||||
observation['observation.images.top'][:, :, -model.n_obs_steps_acting:],
|
|
||||||
observation['observation.state'][:, -model.n_obs_steps_acting:],
|
|
||||||
sim_mode,
|
|
||||||
False,
|
|
||||||
]
|
|
||||||
|
|
||||||
# --- sub-stage: ddim_sampling ---
|
|
||||||
autocast_dtype = getattr(model, 'diffusion_autocast_dtype', None)
|
|
||||||
if autocast_dtype is not None and device.type == 'cuda':
|
|
||||||
autocast_ctx = torch.autocast('cuda', dtype=autocast_dtype)
|
|
||||||
else:
|
|
||||||
autocast_ctx = nullcontext()
|
|
||||||
|
|
||||||
with CudaTimer(f"{prefix}/ddim_sampling", records):
|
|
||||||
with autocast_ctx:
|
|
||||||
samples, actions, states, _ = ddim_sampler.sample(
|
|
||||||
S=ddim_steps,
|
|
||||||
conditioning=cond,
|
|
||||||
batch_size=batch_size,
|
|
||||||
shape=noise_shape[1:],
|
|
||||||
verbose=False,
|
|
||||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
|
||||||
unconditional_conditioning=None,
|
|
||||||
eta=ddim_eta,
|
|
||||||
cfg_img=None,
|
|
||||||
mask=None,
|
|
||||||
x0=None,
|
|
||||||
fs=fs_t,
|
|
||||||
timestep_spacing=timestep_spacing,
|
|
||||||
guidance_rescale=guidance_rescale,
|
|
||||||
unconditional_conditioning_img_nonetext=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- sub-stage: vae_decode ---
|
|
||||||
batch_variants = None
|
|
||||||
if decode_video:
|
|
||||||
with CudaTimer(f"{prefix}/vae_decode", records):
|
|
||||||
batch_variants = model.decode_first_stage(samples)
|
|
||||||
else:
|
|
||||||
records[f"{prefix}/vae_decode"].append(0.0)
|
|
||||||
|
|
||||||
return batch_variants, actions, states
|
|
||||||
|
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────
|
|
||||||
# Instrumented iteration loop
|
|
||||||
# ──────────────────────────────────────────────────────────────────────
|
|
||||||
def run_profiled_iterations(model, args, config, noise_shape, device):
|
|
||||||
"""Run the full iteration loop with per-stage timing.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
all_records: list of dicts, one per itr, {stage_name: ms}
|
|
||||||
"""
|
|
||||||
# Load data
|
|
||||||
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
|
|
||||||
df = pd.read_csv(csv_path)
|
|
||||||
sample = df.iloc[0]
|
|
||||||
|
|
||||||
data_module = instantiate_from_config(config.data)
|
|
||||||
data_module.setup()
|
|
||||||
|
|
||||||
init_frame_path = get_init_frame_path(args.prompt_dir, sample)
|
|
||||||
ori_fps = float(sample['fps'])
|
|
||||||
fs = args.frame_stride
|
|
||||||
model_input_fs = ori_fps // fs
|
|
||||||
|
|
||||||
transition_path = get_transition_path(args.prompt_dir, sample)
|
|
||||||
with h5py.File(transition_path, 'r') as h5f:
|
|
||||||
transition_dict = {}
|
|
||||||
for key in h5f.keys():
|
|
||||||
transition_dict[key] = torch.tensor(h5f[key][()])
|
|
||||||
for key in h5f.attrs.keys():
|
|
||||||
transition_dict[key] = h5f.attrs[key]
|
|
||||||
|
|
||||||
# Prepare initial observation
|
|
||||||
batch, ori_state_dim, ori_action_dim = prepare_init_input(
|
|
||||||
0, init_frame_path, transition_dict, fs,
|
|
||||||
data_module.test_datasets[args.dataset],
|
|
||||||
n_obs_steps=model.n_obs_steps_imagen)
|
|
||||||
|
|
||||||
observation = {
|
|
||||||
'observation.images.top':
|
|
||||||
batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0),
|
|
||||||
'observation.state':
|
|
||||||
batch['observation.state'][-1].unsqueeze(0),
|
|
||||||
'action':
|
|
||||||
torch.zeros_like(batch['action'][-1]).unsqueeze(0),
|
|
||||||
}
|
|
||||||
observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
|
|
||||||
|
|
||||||
cond_obs_queues = {
|
|
||||||
"observation.images.top": deque(maxlen=model.n_obs_steps_imagen),
|
|
||||||
"observation.state": deque(maxlen=model.n_obs_steps_imagen),
|
|
||||||
"action": deque(maxlen=args.video_length),
|
|
||||||
}
|
|
||||||
cond_obs_queues = populate_queues(cond_obs_queues, observation)
|
|
||||||
|
|
||||||
# Temp dir for save_results profiling
|
|
||||||
tmp_dir = os.path.join(args.savedir, "profile_tmp")
|
|
||||||
os.makedirs(tmp_dir, exist_ok=True)
|
|
||||||
|
|
||||||
prompt_text = sample['instruction']
|
|
||||||
all_records = []
|
|
||||||
|
|
||||||
print(f">>> Running {args.n_iter} profiled iterations ...")
|
|
||||||
for itr in range(args.n_iter):
|
|
||||||
rec = defaultdict(list)
|
|
||||||
|
|
||||||
# ── itr_total start ──
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
itr_start = torch.cuda.Event(enable_timing=True)
|
|
||||||
itr_end = torch.cuda.Event(enable_timing=True)
|
|
||||||
itr_start.record()
|
|
||||||
|
|
||||||
# ① stack_to_device_1
|
|
||||||
with CudaTimer("stack_to_device_1", rec):
|
|
||||||
observation = {
|
|
||||||
'observation.images.top':
|
|
||||||
torch.stack(list(cond_obs_queues['observation.images.top']),
|
|
||||||
dim=1).permute(0, 2, 1, 3, 4),
|
|
||||||
'observation.state':
|
|
||||||
torch.stack(list(cond_obs_queues['observation.state']), dim=1),
|
|
||||||
'action':
|
|
||||||
torch.stack(list(cond_obs_queues['action']), dim=1),
|
|
||||||
}
|
|
||||||
observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
|
|
||||||
|
|
||||||
# ② synth_policy
|
|
||||||
with CudaTimer("synth_policy", rec):
|
|
||||||
pred_videos_0, pred_actions, _ = profiled_synthesis(
|
|
||||||
model, prompt_text, observation, noise_shape,
|
|
||||||
ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
|
|
||||||
unconditional_guidance_scale=args.unconditional_guidance_scale,
|
|
||||||
fs=model_input_fs, text_input=True,
|
|
||||||
timestep_spacing=args.timestep_spacing,
|
|
||||||
guidance_rescale=args.guidance_rescale,
|
|
||||||
sim_mode=False,
|
|
||||||
decode_video=not args.fast_policy_no_decode,
|
|
||||||
records=rec, prefix="policy")
|
|
||||||
|
|
||||||
# ③ update_action_queue
|
|
||||||
with WallTimer("update_action_queue", rec):
|
|
||||||
for idx in range(len(pred_actions[0])):
|
|
||||||
obs_a = {'action': pred_actions[0][idx:idx + 1]}
|
|
||||||
obs_a['action'][:, ori_action_dim:] = 0.0
|
|
||||||
cond_obs_queues = populate_queues(cond_obs_queues, obs_a)
|
|
||||||
|
|
||||||
# ④ stack_to_device_2
|
|
||||||
with CudaTimer("stack_to_device_2", rec):
|
|
||||||
observation = {
|
|
||||||
'observation.images.top':
|
|
||||||
torch.stack(list(cond_obs_queues['observation.images.top']),
|
|
||||||
dim=1).permute(0, 2, 1, 3, 4),
|
|
||||||
'observation.state':
|
|
||||||
torch.stack(list(cond_obs_queues['observation.state']), dim=1),
|
|
||||||
'action':
|
|
||||||
torch.stack(list(cond_obs_queues['action']), dim=1),
|
|
||||||
}
|
|
||||||
observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
|
|
||||||
|
|
||||||
# ⑤ synth_world_model
|
|
||||||
with CudaTimer("synth_world_model", rec):
|
|
||||||
pred_videos_1, _, pred_states = profiled_synthesis(
|
|
||||||
model, "", observation, noise_shape,
|
|
||||||
ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
|
|
||||||
unconditional_guidance_scale=args.unconditional_guidance_scale,
|
|
||||||
fs=model_input_fs, text_input=False,
|
|
||||||
timestep_spacing=args.timestep_spacing,
|
|
||||||
guidance_rescale=args.guidance_rescale,
|
|
||||||
sim_mode=True, decode_video=True,
|
|
||||||
records=rec, prefix="wm")
|
|
||||||
|
|
||||||
# ⑥ update_obs_queue
|
|
||||||
with WallTimer("update_obs_queue", rec):
|
|
||||||
for idx in range(args.exe_steps):
|
|
||||||
obs_u = {
|
|
||||||
'observation.images.top':
|
|
||||||
pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
|
|
||||||
'observation.state':
|
|
||||||
pred_states[0][idx:idx + 1],
|
|
||||||
'action':
|
|
||||||
torch.zeros_like(pred_actions[0][-1:]),
|
|
||||||
}
|
|
||||||
obs_u['observation.state'][:, ori_state_dim:] = 0.0
|
|
||||||
cond_obs_queues = populate_queues(cond_obs_queues, obs_u)
|
|
||||||
|
|
||||||
# ⑦ tensorboard_log (simulate — no actual writer, measure make_grid cost)
|
|
||||||
with WallTimer("tensorboard_log", rec):
|
|
||||||
for vid in [pred_videos_0, pred_videos_1]:
|
|
||||||
if vid is not None and vid.dim() == 5:
|
|
||||||
v = vid.permute(2, 0, 1, 3, 4)
|
|
||||||
grids = [torchvision.utils.make_grid(f, nrow=1, padding=0) for f in v]
|
|
||||||
_ = torch.stack(grids, dim=0)
|
|
||||||
|
|
||||||
# ⑧ save_results
|
|
||||||
with WallTimer("save_results", rec):
|
|
||||||
if pred_videos_0 is not None:
|
|
||||||
save_results(pred_videos_0.cpu(),
|
|
||||||
os.path.join(tmp_dir, f"dm_{itr}.mp4"),
|
|
||||||
fps=args.save_fps)
|
|
||||||
save_results(pred_videos_1.cpu(),
|
|
||||||
os.path.join(tmp_dir, f"wm_{itr}.mp4"),
|
|
||||||
fps=args.save_fps)
|
|
||||||
|
|
||||||
# ⑨ cpu_transfer
|
|
||||||
with CudaTimer("cpu_transfer", rec):
|
|
||||||
_ = pred_videos_1[:, :, :args.exe_steps].cpu()
|
|
||||||
|
|
||||||
# ── itr_total end ──
|
|
||||||
itr_end.record()
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
itr_total_ms = itr_start.elapsed_time(itr_end)
|
|
||||||
rec["itr_total"].append(itr_total_ms)
|
|
||||||
|
|
||||||
# Flatten: each stage has exactly one entry per itr
|
|
||||||
itr_rec = {k: v[0] for k, v in rec.items()}
|
|
||||||
all_records.append(itr_rec)
|
|
||||||
|
|
||||||
# Print live progress
|
|
||||||
print(f" itr {itr}: {itr_total_ms:.0f} ms total | "
|
|
||||||
f"policy={itr_rec.get('synth_policy', 0):.0f} | "
|
|
||||||
f"wm={itr_rec.get('synth_world_model', 0):.0f} | "
|
|
||||||
f"save={itr_rec.get('save_results', 0):.0f} | "
|
|
||||||
f"tb={itr_rec.get('tensorboard_log', 0):.0f}")
|
|
||||||
|
|
||||||
return all_records
|
|
||||||
|
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────
|
|
||||||
# Layer 1: Console report
|
|
||||||
# ──────────────────────────────────────────────────────────────────────
|
|
||||||
def print_iteration_report(all_records, warmup=1):
|
|
||||||
"""Print a structured table of per-stage timing across iterations."""
|
|
||||||
if len(all_records) <= warmup:
|
|
||||||
records = all_records
|
|
||||||
else:
|
|
||||||
records = all_records[warmup:]
|
|
||||||
print(f"\n(Skipping first {warmup} itr(s) as warmup)\n")
|
|
||||||
|
|
||||||
# Collect all stage keys in a stable order
|
|
||||||
all_keys = []
|
|
||||||
seen = set()
|
|
||||||
for rec in records:
|
|
||||||
for k in rec:
|
|
||||||
if k not in seen:
|
|
||||||
all_keys.append(k)
|
|
||||||
seen.add(k)
|
|
||||||
|
|
||||||
# Separate top-level stages from sub-stages
|
|
||||||
top_keys = [k for k in all_keys if '/' not in k]
|
|
||||||
sub_keys = [k for k in all_keys if '/' in k]
|
|
||||||
|
|
||||||
def _print_table(keys, title):
|
|
||||||
if not keys:
|
|
||||||
return
|
|
||||||
print("=" * 82)
|
|
||||||
print(title)
|
|
||||||
print("=" * 82)
|
|
||||||
print(f"{'Stage':<35} {'Mean(ms)':>10} {'Std':>8} {'Min':>10} {'Max':>10} {'%':>7}")
|
|
||||||
print("-" * 82)
|
|
||||||
|
|
||||||
total_mean = np.mean([rec.get("itr_total", 0) for rec in records])
|
|
||||||
for k in keys:
|
|
||||||
vals = [rec.get(k, 0) for rec in records]
|
|
||||||
mean = np.mean(vals)
|
|
||||||
std = np.std(vals)
|
|
||||||
mn = np.min(vals)
|
|
||||||
mx = np.max(vals)
|
|
||||||
pct = mean / total_mean * 100 if total_mean > 0 else 0
|
|
||||||
print(f"{k:<35} {mean:>10.1f} {std:>8.1f} {mn:>10.1f} {mx:>10.1f} {pct:>6.1f}%")
|
|
||||||
print("-" * 82)
|
|
||||||
print()
|
|
||||||
|
|
||||||
_print_table(top_keys, "TABLE 1: ITERATION-LEVEL BREAKDOWN")
|
|
||||||
_print_table(sub_keys, "TABLE 2: SYNTHESIS SUB-STAGE BREAKDOWN")
|
|
||||||
|
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────
|
|
||||||
# Layer 3: CSV output for A/B comparison
|
|
||||||
# ──────────────────────────────────────────────────────────────────────
|
|
||||||
def write_csv(all_records, csv_path, warmup=1):
|
|
||||||
"""Write per-iteration timing to CSV for later comparison."""
|
|
||||||
records = all_records[warmup:] if len(all_records) > warmup else all_records
|
|
||||||
|
|
||||||
# Collect all keys
|
|
||||||
all_keys = []
|
|
||||||
seen = set()
|
|
||||||
for rec in records:
|
|
||||||
for k in rec:
|
|
||||||
if k not in seen:
|
|
||||||
all_keys.append(k)
|
|
||||||
seen.add(k)
|
|
||||||
|
|
||||||
with open(csv_path, 'w', newline='') as f:
|
|
||||||
writer = csv.DictWriter(f, fieldnames=['itr'] + all_keys)
|
|
||||||
writer.writeheader()
|
|
||||||
for i, rec in enumerate(records):
|
|
||||||
row = {'itr': i}
|
|
||||||
row.update({k: f"{rec.get(k, 0):.2f}" for k in all_keys})
|
|
||||||
writer.writerow(row)
|
|
||||||
|
|
||||||
# Also write a summary row
|
|
||||||
summary_path = csv_path.replace('.csv', '_summary.csv')
|
|
||||||
with open(summary_path, 'w', newline='') as f:
|
|
||||||
writer = csv.DictWriter(f, fieldnames=['stat'] + all_keys)
|
|
||||||
writer.writeheader()
|
|
||||||
for stat_name, stat_fn in [('mean', np.mean), ('std', np.std),
|
|
||||||
('min', np.min), ('max', np.max)]:
|
|
||||||
row = {'stat': stat_name}
|
|
||||||
row.update({k: f"{stat_fn([r.get(k, 0) for r in records]):.2f}"
|
|
||||||
for k in all_keys})
|
|
||||||
writer.writerow(row)
|
|
||||||
|
|
||||||
print(f">>> CSV written to: {csv_path}")
|
|
||||||
print(f">>> Summary written to: {summary_path}")
|
|
||||||
|
|
||||||
|
|
||||||
def compare_csvs(path_a, path_b):
|
|
||||||
"""Compare two summary CSVs and print a diff table."""
|
|
||||||
df_a = pd.read_csv(path_a, index_col='stat')
|
|
||||||
df_b = pd.read_csv(path_b, index_col='stat')
|
|
||||||
|
|
||||||
# Use mean row for comparison
|
|
||||||
mean_a = df_a.loc['mean'].astype(float)
|
|
||||||
mean_b = df_b.loc['mean'].astype(float)
|
|
||||||
|
|
||||||
print("=" * 90)
|
|
||||||
print(f"A/B COMPARISON: {os.path.basename(path_a)} vs {os.path.basename(path_b)}")
|
|
||||||
print("=" * 90)
|
|
||||||
print(f"{'Stage':<35} {'A(ms)':>10} {'B(ms)':>10} {'Diff':>10} {'Speedup':>10}")
|
|
||||||
print("-" * 90)
|
|
||||||
|
|
||||||
for col in mean_a.index:
|
|
||||||
if col not in mean_b.index:
|
|
||||||
continue
|
|
||||||
a_val = mean_a[col]
|
|
||||||
b_val = mean_b[col]
|
|
||||||
diff = b_val - a_val
|
|
||||||
speedup = a_val / b_val if b_val > 0 else float('inf')
|
|
||||||
marker = " <<<" if abs(diff) > 50 else ""
|
|
||||||
print(f"{col:<35} {a_val:>10.1f} {b_val:>10.1f} {diff:>+10.1f} {speedup:>9.2f}x{marker}")
|
|
||||||
|
|
||||||
print("-" * 90)
|
|
||||||
total_a = mean_a.get('itr_total', 0)
|
|
||||||
total_b = mean_b.get('itr_total', 0)
|
|
||||||
print(f"{'itr_total':<35} {total_a:>10.1f} {total_b:>10.1f} "
|
|
||||||
f"{total_b - total_a:>+10.1f} {total_a / total_b if total_b > 0 else 0:>9.2f}x")
|
|
||||||
print()
|
|
||||||
|
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────
|
|
||||||
# Layer 2: GPU timeline trace wrapper
|
|
||||||
# ──────────────────────────────────────────────────────────────────────
|
|
||||||
def run_with_trace(model, args, config, noise_shape, device):
|
|
||||||
"""Run iterations under torch.profiler to generate Chrome/TensorBoard traces."""
|
|
||||||
trace_dir = args.trace_dir
|
|
||||||
os.makedirs(trace_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# We need the same data setup as run_profiled_iterations
|
|
||||||
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
|
|
||||||
df = pd.read_csv(csv_path)
|
|
||||||
sample = df.iloc[0]
|
|
||||||
|
|
||||||
data_module = instantiate_from_config(config.data)
|
|
||||||
data_module.setup()
|
|
||||||
|
|
||||||
init_frame_path = get_init_frame_path(args.prompt_dir, sample)
|
|
||||||
ori_fps = float(sample['fps'])
|
|
||||||
fs = args.frame_stride
|
|
||||||
model_input_fs = ori_fps // fs
|
|
||||||
|
|
||||||
transition_path = get_transition_path(args.prompt_dir, sample)
|
|
||||||
with h5py.File(transition_path, 'r') as h5f:
|
|
||||||
transition_dict = {}
|
|
||||||
for key in h5f.keys():
|
|
||||||
transition_dict[key] = torch.tensor(h5f[key][()])
|
|
||||||
for key in h5f.attrs.keys():
|
|
||||||
transition_dict[key] = h5f.attrs[key]
|
|
||||||
|
|
||||||
batch, ori_state_dim, ori_action_dim = prepare_init_input(
|
|
||||||
0, init_frame_path, transition_dict, fs,
|
|
||||||
data_module.test_datasets[args.dataset],
|
|
||||||
n_obs_steps=model.n_obs_steps_imagen)
|
|
||||||
|
|
||||||
observation = {
|
|
||||||
'observation.images.top':
|
|
||||||
batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0),
|
|
||||||
'observation.state':
|
|
||||||
batch['observation.state'][-1].unsqueeze(0),
|
|
||||||
'action':
|
|
||||||
torch.zeros_like(batch['action'][-1]).unsqueeze(0),
|
|
||||||
}
|
|
||||||
observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
|
|
||||||
|
|
||||||
cond_obs_queues = {
|
|
||||||
"observation.images.top": deque(maxlen=model.n_obs_steps_imagen),
|
|
||||||
"observation.state": deque(maxlen=model.n_obs_steps_imagen),
|
|
||||||
"action": deque(maxlen=args.video_length),
|
|
||||||
}
|
|
||||||
cond_obs_queues = populate_queues(cond_obs_queues, observation)
|
|
||||||
|
|
||||||
tmp_dir = os.path.join(args.savedir, "profile_tmp")
|
|
||||||
os.makedirs(tmp_dir, exist_ok=True)
|
|
||||||
prompt_text = sample['instruction']
|
|
||||||
|
|
||||||
# Total iterations: warmup + active
|
|
||||||
n_warmup = 1
|
|
||||||
n_active = min(args.n_iter, 2) # trace 2 active iterations max
|
|
||||||
n_total = n_warmup + n_active
|
|
||||||
|
|
||||||
print(f">>> GPU trace: {n_warmup} warmup + {n_active} active iterations")
|
|
||||||
print(f">>> Trace output: {trace_dir}")
|
|
||||||
|
|
||||||
with torch.no_grad(), torch.profiler.profile(
|
|
||||||
activities=[
|
|
||||||
torch.profiler.ProfilerActivity.CPU,
|
|
||||||
torch.profiler.ProfilerActivity.CUDA,
|
|
||||||
],
|
|
||||||
schedule=torch.profiler.schedule(
|
|
||||||
wait=0, warmup=n_warmup, active=n_active, repeat=1),
|
|
||||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
|
|
||||||
record_shapes=True,
|
|
||||||
with_stack=True,
|
|
||||||
) as prof:
|
|
||||||
for itr_idx in range(n_total):
|
|
||||||
phase = "warmup" if itr_idx < n_warmup else "active"
|
|
||||||
print(f" trace itr {itr_idx} ({phase})...")
|
|
||||||
|
|
||||||
# ── One full iteration (same logic as run_inference) ──
|
|
||||||
obs_loc = {
|
|
||||||
'observation.images.top':
|
|
||||||
torch.stack(list(cond_obs_queues['observation.images.top']),
|
|
||||||
dim=1).permute(0, 2, 1, 3, 4),
|
|
||||||
'observation.state':
|
|
||||||
torch.stack(list(cond_obs_queues['observation.state']), dim=1),
|
|
||||||
'action':
|
|
||||||
torch.stack(list(cond_obs_queues['action']), dim=1),
|
|
||||||
}
|
|
||||||
obs_loc = {k: v.to(device) for k, v in obs_loc.items()}
|
|
||||||
|
|
||||||
# Policy pass
|
|
||||||
dummy_rec = defaultdict(list)
|
|
||||||
pv0, pa, _ = profiled_synthesis(
|
|
||||||
model, prompt_text, obs_loc, noise_shape,
|
|
||||||
ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
|
|
||||||
unconditional_guidance_scale=args.unconditional_guidance_scale,
|
|
||||||
fs=model_input_fs, text_input=True,
|
|
||||||
timestep_spacing=args.timestep_spacing,
|
|
||||||
guidance_rescale=args.guidance_rescale,
|
|
||||||
sim_mode=False,
|
|
||||||
decode_video=not args.fast_policy_no_decode,
|
|
||||||
records=dummy_rec, prefix="policy")
|
|
||||||
|
|
||||||
for idx in range(len(pa[0])):
|
|
||||||
oa = {'action': pa[0][idx:idx + 1]}
|
|
||||||
oa['action'][:, ori_action_dim:] = 0.0
|
|
||||||
populate_queues(cond_obs_queues, oa)
|
|
||||||
|
|
||||||
# Re-stack for world model
|
|
||||||
obs_loc2 = {
|
|
||||||
'observation.images.top':
|
|
||||||
torch.stack(list(cond_obs_queues['observation.images.top']),
|
|
||||||
dim=1).permute(0, 2, 1, 3, 4),
|
|
||||||
'observation.state':
|
|
||||||
torch.stack(list(cond_obs_queues['observation.state']), dim=1),
|
|
||||||
'action':
|
|
||||||
torch.stack(list(cond_obs_queues['action']), dim=1),
|
|
||||||
}
|
|
||||||
obs_loc2 = {k: v.to(device) for k, v in obs_loc2.items()}
|
|
||||||
|
|
||||||
# World model pass
|
|
||||||
pv1, _, ps = profiled_synthesis(
|
|
||||||
model, "", obs_loc2, noise_shape,
|
|
||||||
ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
|
|
||||||
unconditional_guidance_scale=args.unconditional_guidance_scale,
|
|
||||||
fs=model_input_fs, text_input=False,
|
|
||||||
timestep_spacing=args.timestep_spacing,
|
|
||||||
guidance_rescale=args.guidance_rescale,
|
|
||||||
sim_mode=True, decode_video=True,
|
|
||||||
records=dummy_rec, prefix="wm")
|
|
||||||
|
|
||||||
# Update obs queue
|
|
||||||
for idx in range(args.exe_steps):
|
|
||||||
ou = {
|
|
||||||
'observation.images.top':
|
|
||||||
pv1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
|
|
||||||
'observation.state': ps[0][idx:idx + 1],
|
|
||||||
'action': torch.zeros_like(pa[0][-1:]),
|
|
||||||
}
|
|
||||||
ou['observation.state'][:, ori_state_dim:] = 0.0
|
|
||||||
populate_queues(cond_obs_queues, ou)
|
|
||||||
|
|
||||||
# Save results (captures CPU stall in trace)
|
|
||||||
if pv0 is not None:
|
|
||||||
save_results(pv0.cpu(),
|
|
||||||
os.path.join(tmp_dir, f"trace_dm_{itr_idx}.mp4"),
|
|
||||||
fps=args.save_fps)
|
|
||||||
save_results(pv1.cpu(),
|
|
||||||
os.path.join(tmp_dir, f"trace_wm_{itr_idx}.mp4"),
|
|
||||||
fps=args.save_fps)
|
|
||||||
|
|
||||||
prof.step()
|
|
||||||
|
|
||||||
print(f">>> Trace saved to {trace_dir}")
|
|
||||||
print(" View with: tensorboard --logdir", trace_dir)
|
|
||||||
print(" Or open the .json file in chrome://tracing")
|
|
||||||
|
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────
|
|
||||||
# Argument parser
|
|
||||||
# ──────────────────────────────────────────────────────────────────────
|
|
||||||
def get_parser():
|
|
||||||
p = argparse.ArgumentParser(description="Profile full iteration loop")
|
|
||||||
|
|
||||||
# Compare mode (no model needed)
|
|
||||||
p.add_argument("--compare", nargs=2, metavar=("A_SUMMARY", "B_SUMMARY"),
|
|
||||||
help="Compare two summary CSVs and exit")
|
|
||||||
|
|
||||||
# Model / data
|
|
||||||
p.add_argument("--ckpt_path", type=str, default=None)
|
|
||||||
p.add_argument("--config", type=str, default=None)
|
|
||||||
p.add_argument("--prompt_dir", type=str, default=None)
|
|
||||||
p.add_argument("--dataset", type=str, default=None)
|
|
||||||
p.add_argument("--savedir", type=str, default="profile_output")
|
|
||||||
|
|
||||||
# Inference params (match world_model_interaction.py)
|
|
||||||
p.add_argument("--ddim_steps", type=int, default=50)
|
|
||||||
p.add_argument("--ddim_eta", type=float, default=1.0)
|
|
||||||
p.add_argument("--bs", type=int, default=1)
|
|
||||||
p.add_argument("--height", type=int, default=320)
|
|
||||||
p.add_argument("--width", type=int, default=512)
|
|
||||||
p.add_argument("--frame_stride", type=int, default=4)
|
|
||||||
p.add_argument("--unconditional_guidance_scale", type=float, default=1.0)
|
|
||||||
p.add_argument("--video_length", type=int, default=16)
|
|
||||||
p.add_argument("--timestep_spacing", type=str, default="uniform_trailing")
|
|
||||||
p.add_argument("--guidance_rescale", type=float, default=0.7)
|
|
||||||
p.add_argument("--exe_steps", type=int, default=16)
|
|
||||||
p.add_argument("--n_iter", type=int, default=5)
|
|
||||||
p.add_argument("--save_fps", type=int, default=8)
|
|
||||||
p.add_argument("--seed", type=int, default=123)
|
|
||||||
p.add_argument("--perframe_ae", action='store_true', default=False)
|
|
||||||
p.add_argument("--vae_dtype", type=str, choices=["fp32", "bf16"], default="bf16")
|
|
||||||
p.add_argument("--fast_policy_no_decode", action='store_true', default=False)
|
|
||||||
|
|
||||||
# Profiling control
|
|
||||||
p.add_argument("--warmup", type=int, default=1,
|
|
||||||
help="Number of warmup iterations to skip in statistics")
|
|
||||||
p.add_argument("--csv", type=str, default=None,
|
|
||||||
help="Write per-iteration timing to this CSV file")
|
|
||||||
p.add_argument("--trace", action='store_true', default=False,
|
|
||||||
help="Enable Layer 2: GPU timeline trace")
|
|
||||||
p.add_argument("--trace_dir", type=str, default="./profile_traces",
|
|
||||||
help="Directory for trace output")
|
|
||||||
|
|
||||||
return p
|
|
||||||
|
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────
|
|
||||||
# Main
|
|
||||||
# ──────────────────────────────────────────────────────────────────────
|
|
||||||
def main():
|
|
||||||
patch_norm_bypass_autocast()
|
|
||||||
parser = get_parser()
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# ── Compare mode: no model needed ──
|
|
||||||
if args.compare:
|
|
||||||
compare_csvs(args.compare[0], args.compare[1])
|
|
||||||
return
|
|
||||||
|
|
||||||
# ── Validate required args ──
|
|
||||||
for required in ['ckpt_path', 'config', 'prompt_dir', 'dataset']:
|
|
||||||
if getattr(args, required) is None:
|
|
||||||
parser.error(f"--{required} is required for profiling mode")
|
|
||||||
|
|
||||||
seed_everything(args.seed)
|
|
||||||
os.makedirs(args.savedir, exist_ok=True)
|
|
||||||
|
|
||||||
# ── Load model ──
|
|
||||||
print("=" * 60)
|
|
||||||
print("PROFILE ITERATION — Loading model...")
|
|
||||||
print("=" * 60)
|
|
||||||
model, config = load_model(args)
|
|
||||||
device = next(model.parameters()).device
|
|
||||||
|
|
||||||
h, w = args.height // 8, args.width // 8
|
|
||||||
channels = model.model.diffusion_model.out_channels
|
|
||||||
noise_shape = [args.bs, channels, args.video_length, h, w]
|
|
||||||
print(f">>> Noise shape: {noise_shape}")
|
|
||||||
print(f">>> DDIM steps: {args.ddim_steps}")
|
|
||||||
print(f">>> fast_policy_no_decode: {args.fast_policy_no_decode}")
|
|
||||||
|
|
||||||
# ── Layer 2: GPU trace (optional) ──
|
|
||||||
if args.trace:
|
|
||||||
with torch.no_grad():
|
|
||||||
run_with_trace(model, args, config, noise_shape, device)
|
|
||||||
print()
|
|
||||||
|
|
||||||
# ── Layer 1: Iteration-level breakdown ──
|
|
||||||
print("=" * 60)
|
|
||||||
print("LAYER 1: ITERATION-LEVEL PROFILING")
|
|
||||||
print("=" * 60)
|
|
||||||
with torch.no_grad():
|
|
||||||
all_records = run_profiled_iterations(
|
|
||||||
model, args, config, noise_shape, device)
|
|
||||||
|
|
||||||
# Print report
|
|
||||||
print_iteration_report(all_records, warmup=args.warmup)
|
|
||||||
|
|
||||||
# ── Layer 3: CSV output for A/B comparison ──
|
|
||||||
if args.csv:
|
|
||||||
write_csv(all_records, args.csv, warmup=args.warmup)
|
|
||||||
|
|
||||||
print("Done.")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
@@ -1,733 +0,0 @@
|
|||||||
"""
|
|
||||||
Profile the full inference pipeline of the world model, covering all 7 stages:
|
|
||||||
1. Image Embedding
|
|
||||||
2. VAE Encode
|
|
||||||
3. Text Conditioning
|
|
||||||
4. State/Action Projectors
|
|
||||||
5. DDIM Loop
|
|
||||||
6. VAE Decode
|
|
||||||
7. Post-process
|
|
||||||
|
|
||||||
Reports stage-level timing, UNet sub-module breakdown, memory summary,
|
|
||||||
and throughput analysis.
|
|
||||||
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_pipeline.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --ddim_steps 50 --cfg_scale 1.0 --n_runs 3 --deep
|
|
||||||
Usage:
|
|
||||||
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_pipeline.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --ddim_steps 50 --cfg_scale 1.0 --n_runs 3
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch.utils.checkpoint # must be loaded before unifolm_wma.utils.common
|
|
||||||
from contextlib import nullcontext, contextmanager
|
|
||||||
from collections import defaultdict
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
|
|
||||||
from unifolm_wma.utils.utils import instantiate_from_config
|
|
||||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
|
||||||
from unifolm_wma.modules.attention import (
|
|
||||||
SpatialTransformer, TemporalTransformer,
|
|
||||||
BasicTransformerBlock, CrossAttention, FeedForward,
|
|
||||||
)
|
|
||||||
from unifolm_wma.modules.networks.wma_model import ResBlock
|
|
||||||
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
|
|
||||||
|
|
||||||
# --- W7900D theoretical peak ---
|
|
||||||
PEAK_BF16_TFLOPS = 61.0
|
|
||||||
MEM_BW_GBS = 864.0
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Utility: patch norms to bypass autocast fp32 promotion
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
def patch_norm_bypass_autocast():
|
|
||||||
"""Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy."""
|
|
||||||
|
|
||||||
def _group_norm_forward(self, x):
|
|
||||||
with torch.amp.autocast('cuda', enabled=False):
|
|
||||||
return F.group_norm(
|
|
||||||
x, self.num_groups,
|
|
||||||
self.weight.to(x.dtype) if self.weight is not None else None,
|
|
||||||
self.bias.to(x.dtype) if self.bias is not None else None,
|
|
||||||
self.eps)
|
|
||||||
|
|
||||||
def _layer_norm_forward(self, x):
|
|
||||||
with torch.amp.autocast('cuda', enabled=False):
|
|
||||||
return F.layer_norm(
|
|
||||||
x, self.normalized_shape,
|
|
||||||
self.weight.to(x.dtype) if self.weight is not None else None,
|
|
||||||
self.bias.to(x.dtype) if self.bias is not None else None,
|
|
||||||
self.eps)
|
|
||||||
|
|
||||||
torch.nn.GroupNorm.forward = _group_norm_forward
|
|
||||||
torch.nn.LayerNorm.forward = _layer_norm_forward
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Utility: torch.compile hot ResBlocks
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
|
|
||||||
"""Compile ResBlock._forward in the hottest output_blocks for operator fusion."""
|
|
||||||
unet = model.model.diffusion_model
|
|
||||||
compiled = 0
|
|
||||||
for idx in hot_indices:
|
|
||||||
block = unet.output_blocks[idx]
|
|
||||||
for layer in block:
|
|
||||||
if isinstance(layer, ResBlock):
|
|
||||||
layer._forward = torch.compile(layer._forward, mode="default")
|
|
||||||
compiled += 1
|
|
||||||
print(f" torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Model loading
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
def load_model(args):
|
|
||||||
config = OmegaConf.load(args.config)
|
|
||||||
config['model']['params']['wma_config']['params']['use_checkpoint'] = False
|
|
||||||
model = instantiate_from_config(config.model)
|
|
||||||
|
|
||||||
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
|
||||||
if "state_dict" in state_dict:
|
|
||||||
state_dict = state_dict["state_dict"]
|
|
||||||
model.load_state_dict(state_dict, strict=True)
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
model.model.to(torch.bfloat16)
|
|
||||||
model.diffusion_autocast_dtype = torch.bfloat16
|
|
||||||
apply_torch_compile(model)
|
|
||||||
model = model.cuda()
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# CudaTimer — precise GPU timing via CUDA events
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
class CudaTimer:
|
|
||||||
"""Context manager for GPU-precise stage timing using CUDA events."""
|
|
||||||
|
|
||||||
def __init__(self, name, records):
|
|
||||||
self.name = name
|
|
||||||
self.records = records
|
|
||||||
self.start = torch.cuda.Event(enable_timing=True)
|
|
||||||
self.end = torch.cuda.Event(enable_timing=True)
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
self.start.record()
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, *args):
|
|
||||||
self.end.record()
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
elapsed = self.start.elapsed_time(self.end)
|
|
||||||
self.records[self.name].append(elapsed)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# HookProfiler — sub-module level timing inside UNet via hooks
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
class HookProfiler:
|
|
||||||
"""Register forward hooks on UNet sub-modules to collect per-call timing."""
|
|
||||||
|
|
||||||
# Coarse-grained targets (original)
|
|
||||||
COARSE_CLASSES = (
|
|
||||||
SpatialTransformer,
|
|
||||||
TemporalTransformer,
|
|
||||||
ResBlock,
|
|
||||||
ConditionalUnet1D,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fine-grained targets for deep DDIM analysis
|
|
||||||
FINE_CLASSES = (
|
|
||||||
CrossAttention,
|
|
||||||
FeedForward,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self, unet, deep=False):
|
|
||||||
self.unet = unet
|
|
||||||
self.deep = deep
|
|
||||||
self.handles = []
|
|
||||||
# per-instance data: {instance_id: [(start_event, end_event), ...]}
|
|
||||||
self._events = defaultdict(list)
|
|
||||||
# tag mapping: {instance_id: (class_name, module_name)}
|
|
||||||
self._tags = {}
|
|
||||||
# block location: {instance_id: block_location_str}
|
|
||||||
self._block_loc = {}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_block_location(name):
|
|
||||||
"""Derive UNet block location from module name, e.g. 'input_blocks.3.1'."""
|
|
||||||
parts = name.split('.')
|
|
||||||
if len(parts) >= 2 and parts[0] == 'input_blocks':
|
|
||||||
return f"input_blocks.{parts[1]}"
|
|
||||||
elif len(parts) >= 1 and parts[0] == 'middle_block':
|
|
||||||
return "middle_block"
|
|
||||||
elif len(parts) >= 2 and parts[0] == 'output_blocks':
|
|
||||||
return f"output_blocks.{parts[1]}"
|
|
||||||
elif 'action_unet' in name:
|
|
||||||
return "action_unet"
|
|
||||||
elif 'state_unet' in name:
|
|
||||||
return "state_unet"
|
|
||||||
elif name == 'out' or name.startswith('out.'):
|
|
||||||
return "out"
|
|
||||||
return "other"
|
|
||||||
|
|
||||||
def register(self):
|
|
||||||
"""Attach pre/post forward hooks to target sub-modules + unet.out."""
|
|
||||||
target_classes = self.COARSE_CLASSES
|
|
||||||
if self.deep:
|
|
||||||
target_classes = target_classes + self.FINE_CLASSES
|
|
||||||
|
|
||||||
for name, mod in self.unet.named_modules():
|
|
||||||
if isinstance(mod, target_classes):
|
|
||||||
tag = type(mod).__name__
|
|
||||||
inst_id = id(mod)
|
|
||||||
self._tags[inst_id] = (tag, name)
|
|
||||||
self._block_loc[inst_id] = self._get_block_location(name)
|
|
||||||
self.handles.append(
|
|
||||||
mod.register_forward_pre_hook(self._make_pre_hook(inst_id)))
|
|
||||||
self.handles.append(
|
|
||||||
mod.register_forward_hook(self._make_post_hook(inst_id)))
|
|
||||||
|
|
||||||
# Also hook unet.out (nn.Sequential)
|
|
||||||
out_mod = self.unet.out
|
|
||||||
inst_id = id(out_mod)
|
|
||||||
self._tags[inst_id] = ("UNet.out", "out")
|
|
||||||
self._block_loc[inst_id] = "out"
|
|
||||||
self.handles.append(
|
|
||||||
out_mod.register_forward_pre_hook(self._make_pre_hook(inst_id)))
|
|
||||||
self.handles.append(
|
|
||||||
out_mod.register_forward_hook(self._make_post_hook(inst_id)))
|
|
||||||
|
|
||||||
def _make_pre_hook(self, inst_id):
|
|
||||||
events = self._events
|
|
||||||
|
|
||||||
def hook(module, input):
|
|
||||||
start = torch.cuda.Event(enable_timing=True)
|
|
||||||
start.record()
|
|
||||||
events[inst_id].append([start, None])
|
|
||||||
return hook
|
|
||||||
|
|
||||||
def _make_post_hook(self, inst_id):
|
|
||||||
events = self._events
|
|
||||||
|
|
||||||
def hook(module, input, output):
|
|
||||||
end = torch.cuda.Event(enable_timing=True)
|
|
||||||
end.record()
|
|
||||||
events[inst_id][-1][1] = end
|
|
||||||
return hook
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
"""Clear collected events for a fresh run."""
|
|
||||||
self._events.clear()
|
|
||||||
|
|
||||||
def synchronize_and_collect(self):
|
|
||||||
"""Sync GPU and compute elapsed times. Returns (by_type, by_instance, by_block)."""
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
by_type = defaultdict(lambda: {"total_ms": 0.0, "count": 0, "calls": []})
|
|
||||||
by_instance = {}
|
|
||||||
# by_block: {block_loc: {tag: {"total_ms", "count"}}}
|
|
||||||
by_block = defaultdict(lambda: defaultdict(lambda: {"total_ms": 0.0, "count": 0}))
|
|
||||||
|
|
||||||
for inst_id, pairs in self._events.items():
|
|
||||||
tag, mod_name = self._tags[inst_id]
|
|
||||||
block_loc = self._block_loc.get(inst_id, "other")
|
|
||||||
inst_times = []
|
|
||||||
for start_evt, end_evt in pairs:
|
|
||||||
if end_evt is not None:
|
|
||||||
ms = start_evt.elapsed_time(end_evt)
|
|
||||||
inst_times.append(ms)
|
|
||||||
by_type[tag]["total_ms"] += ms
|
|
||||||
by_type[tag]["count"] += 1
|
|
||||||
by_type[tag]["calls"].append(ms)
|
|
||||||
by_block[block_loc][tag]["total_ms"] += ms
|
|
||||||
by_block[block_loc][tag]["count"] += 1
|
|
||||||
by_instance[(tag, mod_name)] = inst_times
|
|
||||||
|
|
||||||
return dict(by_type), by_instance, dict(by_block)
|
|
||||||
|
|
||||||
def remove(self):
|
|
||||||
"""Remove all hooks."""
|
|
||||||
for h in self.handles:
|
|
||||||
h.remove()
|
|
||||||
self.handles.clear()
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Build dummy inputs matching the pipeline's expected shapes
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
def build_dummy_inputs(model, noise_shape):
|
|
||||||
"""Create synthetic observation dict and prompts for profiling."""
|
|
||||||
device = next(model.parameters()).device
|
|
||||||
B, C, T, H, W = noise_shape
|
|
||||||
dtype = torch.bfloat16
|
|
||||||
|
|
||||||
# observation.images.top: [B, C, O, H, W] (permuted to [B,O,C,H,W] inside pipeline)
|
|
||||||
O = 2
|
|
||||||
obs_images = torch.randn(B, 3, O, 320, 512, device=device, dtype=dtype)
|
|
||||||
obs_state = torch.randn(B, O, 16, device=device, dtype=dtype)
|
|
||||||
action = torch.randn(B, 16, 16, device=device, dtype=dtype)
|
|
||||||
|
|
||||||
observation = {
|
|
||||||
'observation.images.top': obs_images,
|
|
||||||
'observation.state': obs_state,
|
|
||||||
'action': action,
|
|
||||||
}
|
|
||||||
prompts = ["a robot arm performing a task"] * B
|
|
||||||
return observation, prompts
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Run one full pipeline pass with per-stage timing
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
def run_pipeline(model, observation, prompts, noise_shape, ddim_steps,
|
|
||||||
cfg_scale, hook_profiler):
|
|
||||||
"""Execute the full 7-stage pipeline, returning per-stage timing dict."""
|
|
||||||
records = defaultdict(list)
|
|
||||||
device = next(model.parameters()).device
|
|
||||||
B, C, T, H, W = noise_shape
|
|
||||||
dtype = torch.bfloat16
|
|
||||||
|
|
||||||
fs = torch.tensor([1] * B, dtype=torch.long, device=device)
|
|
||||||
|
|
||||||
# --- Stage 1: Image Embedding ---
|
|
||||||
with CudaTimer("1_Image_Embedding", records):
|
|
||||||
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
|
||||||
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=dtype)
|
|
||||||
with torch.autocast('cuda', dtype=torch.bfloat16):
|
|
||||||
cond_img_emb = model.embedder(cond_img)
|
|
||||||
cond_img_emb = model.image_proj_model(cond_img_emb)
|
|
||||||
|
|
||||||
# --- Stage 2: VAE Encode ---
|
|
||||||
with CudaTimer("2_VAE_Encode", records):
|
|
||||||
videos = img.permute(0, 2, 1, 3, 4) # [B, C, O, H, W]
|
|
||||||
b_v, c_v, t_v, h_v, w_v = videos.shape
|
|
||||||
vae_dtype = next(model.first_stage_model.parameters()).dtype
|
|
||||||
x_vae = rearrange(videos, 'b c t h w -> (b t) c h w').to(dtype=vae_dtype)
|
|
||||||
z = model.encode_first_stage(x_vae)
|
|
||||||
z = rearrange(z, '(b t) c h w -> b c t h w', b=b_v, t=t_v)
|
|
||||||
img_cat_cond = z[:, :, -1:, :, :]
|
|
||||||
img_cat_cond = repeat(img_cat_cond,
|
|
||||||
'b c t h w -> b c (repeat t) h w', repeat=T)
|
|
||||||
cond = {"c_concat": [img_cat_cond]}
|
|
||||||
|
|
||||||
vae_enc_input_bytes = x_vae.nelement() * x_vae.element_size()
|
|
||||||
vae_enc_output_bytes = z.nelement() * z.element_size()
|
|
||||||
|
|
||||||
# --- Stage 3: Text Conditioning ---
|
|
||||||
with CudaTimer("3_Text_Conditioning", records):
|
|
||||||
cond_ins_emb = model.get_learned_conditioning(prompts)
|
|
||||||
|
|
||||||
# --- Stage 4: State/Action Projectors ---
|
|
||||||
with CudaTimer("4_Projectors", records):
|
|
||||||
projector_dtype = next(model.state_projector.parameters()).dtype
|
|
||||||
with torch.autocast('cuda', dtype=torch.bfloat16):
|
|
||||||
cond_state_emb = model.state_projector(
|
|
||||||
observation['observation.state'].to(dtype=projector_dtype))
|
|
||||||
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
|
|
||||||
|
|
||||||
cond_action_emb = model.action_projector(
|
|
||||||
observation['action'].to(dtype=projector_dtype))
|
|
||||||
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
|
|
||||||
|
|
||||||
# Assemble cross-attention conditioning
|
|
||||||
cond["c_crossattn"] = [
|
|
||||||
torch.cat([cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb],
|
|
||||||
dim=1)
|
|
||||||
]
|
|
||||||
n_obs_acting = getattr(model, 'n_obs_steps_acting', 2)
|
|
||||||
cond["c_crossattn_action"] = [
|
|
||||||
observation['observation.images.top'][:, :, -n_obs_acting:],
|
|
||||||
observation['observation.state'][:, -n_obs_acting:],
|
|
||||||
True, # sim_mode
|
|
||||||
False,
|
|
||||||
]
|
|
||||||
|
|
||||||
# CFG: build unconditional conditioning if needed
|
|
||||||
uc = None
|
|
||||||
if cfg_scale != 1.0:
|
|
||||||
uc_crossattn = torch.zeros_like(cond["c_crossattn"][0])
|
|
||||||
uc = {
|
|
||||||
"c_concat": cond["c_concat"],
|
|
||||||
"c_crossattn": [uc_crossattn],
|
|
||||||
"c_crossattn_action": cond["c_crossattn_action"],
|
|
||||||
}
|
|
||||||
|
|
||||||
# --- Stage 5: DDIM Loop ---
|
|
||||||
ddim_sampler = DDIMSampler(model)
|
|
||||||
hook_profiler.reset()
|
|
||||||
|
|
||||||
with CudaTimer("5_DDIM_Loop", records):
|
|
||||||
with torch.autocast('cuda', dtype=torch.bfloat16):
|
|
||||||
samples, actions, states, _ = ddim_sampler.sample(
|
|
||||||
S=ddim_steps,
|
|
||||||
conditioning=cond,
|
|
||||||
batch_size=B,
|
|
||||||
shape=noise_shape[1:],
|
|
||||||
verbose=False,
|
|
||||||
unconditional_guidance_scale=cfg_scale,
|
|
||||||
unconditional_conditioning=uc,
|
|
||||||
eta=1.0,
|
|
||||||
cfg_img=None,
|
|
||||||
mask=None,
|
|
||||||
x0=None,
|
|
||||||
fs=fs,
|
|
||||||
timestep_spacing='uniform',
|
|
||||||
guidance_rescale=0.0,
|
|
||||||
unconditional_conditioning_img_nonetext=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
hook_by_type, hook_by_instance, hook_by_block = hook_profiler.synchronize_and_collect()
|
|
||||||
|
|
||||||
# --- Stage 6: VAE Decode ---
|
|
||||||
with CudaTimer("6_VAE_Decode", records):
|
|
||||||
batch_images = model.decode_first_stage(samples)
|
|
||||||
|
|
||||||
vae_dec_input_bytes = samples.nelement() * samples.element_size()
|
|
||||||
vae_dec_output_bytes = batch_images.nelement() * batch_images.element_size()
|
|
||||||
|
|
||||||
# --- Stage 7: Post-process ---
|
|
||||||
with CudaTimer("7_Post_Process", records):
|
|
||||||
batch_images_cpu = batch_images.cpu()
|
|
||||||
actions_cpu = actions.cpu()
|
|
||||||
states_cpu = states.cpu()
|
|
||||||
# Simulate video save overhead: clamp + uint8 conversion
|
|
||||||
_ = (batch_images_cpu.clamp(-1, 1) * 127.5 + 127.5).to(torch.uint8)
|
|
||||||
|
|
||||||
# Flatten single-element lists
|
|
||||||
stage_times = {k: v[0] for k, v in records.items()}
|
|
||||||
|
|
||||||
bandwidth_info = {
|
|
||||||
"vae_enc_input_bytes": vae_enc_input_bytes,
|
|
||||||
"vae_enc_output_bytes": vae_enc_output_bytes,
|
|
||||||
"vae_dec_input_bytes": vae_dec_input_bytes,
|
|
||||||
"vae_dec_output_bytes": vae_dec_output_bytes,
|
|
||||||
}
|
|
||||||
|
|
||||||
return stage_times, hook_by_type, hook_by_instance, hook_by_block, bandwidth_info
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Reporting
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
def print_stage_timing(all_runs_stages):
|
|
||||||
"""Table 1: Stage Timing — name | mean(ms) | std | percent."""
|
|
||||||
import numpy as np
|
|
||||||
stage_names = list(all_runs_stages[0].keys())
|
|
||||||
means = {}
|
|
||||||
stds = {}
|
|
||||||
for name in stage_names:
|
|
||||||
vals = [run[name] for run in all_runs_stages]
|
|
||||||
means[name] = np.mean(vals)
|
|
||||||
stds[name] = np.std(vals)
|
|
||||||
total = sum(means.values())
|
|
||||||
|
|
||||||
print()
|
|
||||||
print("=" * 72)
|
|
||||||
print("TABLE 1: STAGE TIMING")
|
|
||||||
print("=" * 72)
|
|
||||||
print(f"{'Stage':<25} {'Mean(ms)':>10} {'Std':>10} {'%':>8}")
|
|
||||||
print("-" * 72)
|
|
||||||
for name in stage_names:
|
|
||||||
pct = means[name] / total * 100 if total > 0 else 0
|
|
||||||
print(f"{name:<25} {means[name]:>10.1f} {stds[name]:>10.2f} {pct:>7.1f}%")
|
|
||||||
print("-" * 72)
|
|
||||||
print(f"{'TOTAL':<25} {total:>10.1f}")
|
|
||||||
print()
|
|
||||||
|
|
||||||
|
|
||||||
def print_unet_breakdown(all_runs_hooks):
|
|
||||||
"""Table 2: UNet Sub-Module Breakdown — type | total | count | per-call | percent."""
|
|
||||||
import numpy as np
|
|
||||||
# Aggregate across runs
|
|
||||||
agg = defaultdict(lambda: {"totals": [], "counts": []})
|
|
||||||
for hook_by_type in all_runs_hooks:
|
|
||||||
for tag, data in hook_by_type.items():
|
|
||||||
agg[tag]["totals"].append(data["total_ms"])
|
|
||||||
agg[tag]["counts"].append(data["count"])
|
|
||||||
|
|
||||||
print("=" * 80)
|
|
||||||
print("TABLE 2: UNET SUB-MODULE BREAKDOWN")
|
|
||||||
print("=" * 80)
|
|
||||||
print(f"{'Module Type':<25} {'Total(ms)':>10} {'Count':>7} {'Per-call':>10} {'%':>8}")
|
|
||||||
print("-" * 80)
|
|
||||||
|
|
||||||
grand_total = 0
|
|
||||||
rows = []
|
|
||||||
for tag, d in agg.items():
|
|
||||||
mean_total = np.mean(d["totals"])
|
|
||||||
mean_count = np.mean(d["counts"])
|
|
||||||
per_call = mean_total / mean_count if mean_count > 0 else 0
|
|
||||||
grand_total += mean_total
|
|
||||||
rows.append((tag, mean_total, mean_count, per_call))
|
|
||||||
|
|
||||||
rows.sort(key=lambda r: r[1], reverse=True)
|
|
||||||
for tag, mean_total, mean_count, per_call in rows:
|
|
||||||
pct = mean_total / grand_total * 100 if grand_total > 0 else 0
|
|
||||||
print(f"{tag:<25} {mean_total:>10.1f} {int(mean_count):>7} {per_call:>10.2f} {pct:>7.1f}%")
|
|
||||||
print("-" * 80)
|
|
||||||
print(f"{'TOTAL (hooked)':<25} {grand_total:>10.1f}")
|
|
||||||
print()
|
|
||||||
|
|
||||||
|
|
||||||
def print_block_timing(all_runs_blocks):
|
|
||||||
"""Table 2b: Per-UNet-block timing — which blocks are hottest."""
|
|
||||||
import numpy as np
|
|
||||||
# Aggregate: {block_loc: {tag: [total_ms_per_run, ...]}}
|
|
||||||
agg = defaultdict(lambda: defaultdict(list))
|
|
||||||
for by_block in all_runs_blocks:
|
|
||||||
for block_loc, tag_dict in by_block.items():
|
|
||||||
for tag, data in tag_dict.items():
|
|
||||||
agg[block_loc][tag].append(data["total_ms"])
|
|
||||||
|
|
||||||
# Compute per-block totals
|
|
||||||
block_totals = {}
|
|
||||||
for block_loc, tag_dict in agg.items():
|
|
||||||
block_totals[block_loc] = sum(np.mean(v) for v in tag_dict.values())
|
|
||||||
|
|
||||||
grand_total = sum(block_totals.values())
|
|
||||||
|
|
||||||
# Sort blocks in logical order
|
|
||||||
def block_sort_key(name):
|
|
||||||
if name.startswith("input_blocks."):
|
|
||||||
return (0, int(name.split('.')[1]))
|
|
||||||
elif name == "middle_block":
|
|
||||||
return (1, 0)
|
|
||||||
elif name.startswith("output_blocks."):
|
|
||||||
return (2, int(name.split('.')[1]))
|
|
||||||
elif name == "out":
|
|
||||||
return (3, 0)
|
|
||||||
elif name == "action_unet":
|
|
||||||
return (4, 0)
|
|
||||||
elif name == "state_unet":
|
|
||||||
return (5, 0)
|
|
||||||
return (9, 0)
|
|
||||||
|
|
||||||
sorted_blocks = sorted(block_totals.keys(), key=block_sort_key)
|
|
||||||
|
|
||||||
print("=" * 90)
|
|
||||||
print("TABLE 2b: PER-UNET-BLOCK TIMING (coarse modules, per DDIM loop)")
|
|
||||||
print("=" * 90)
|
|
||||||
print(f"{'Block':<22} {'Total(ms)':>10} {'%':>7} Breakdown")
|
|
||||||
print("-" * 90)
|
|
||||||
|
|
||||||
for block_loc in sorted_blocks:
|
|
||||||
total = block_totals[block_loc]
|
|
||||||
pct = total / grand_total * 100 if grand_total > 0 else 0
|
|
||||||
# Build breakdown string
|
|
||||||
parts = []
|
|
||||||
for tag, vals in sorted(agg[block_loc].items(),
|
|
||||||
key=lambda x: np.mean(x[1]), reverse=True):
|
|
||||||
parts.append(f"{tag}={np.mean(vals):.0f}")
|
|
||||||
breakdown = ", ".join(parts)
|
|
||||||
print(f"{block_loc:<22} {total:>10.1f} {pct:>6.1f}% {breakdown}")
|
|
||||||
|
|
||||||
print("-" * 90)
|
|
||||||
print(f"{'TOTAL':<22} {grand_total:>10.1f}")
|
|
||||||
print()
|
|
||||||
|
|
||||||
|
|
||||||
def print_attn_ff_breakdown(all_runs_hooks):
|
|
||||||
"""Table 2c: CrossAttention vs FeedForward breakdown (--deep mode)."""
|
|
||||||
import numpy as np
|
|
||||||
agg = defaultdict(list)
|
|
||||||
for hook_by_type in all_runs_hooks:
|
|
||||||
for tag, data in hook_by_type.items():
|
|
||||||
if tag in ("CrossAttention", "FeedForward"):
|
|
||||||
agg[tag].append(data["total_ms"])
|
|
||||||
|
|
||||||
if not agg:
|
|
||||||
return
|
|
||||||
|
|
||||||
print("=" * 70)
|
|
||||||
print("TABLE 2c: ATTENTION vs FEEDFORWARD (deep hooks)")
|
|
||||||
print("=" * 70)
|
|
||||||
print(f"{'Component':<25} {'Total(ms)':>10} {'%':>8}")
|
|
||||||
print("-" * 70)
|
|
||||||
|
|
||||||
grand = 0
|
|
||||||
rows = []
|
|
||||||
for tag in ("CrossAttention", "FeedForward"):
|
|
||||||
if tag in agg:
|
|
||||||
mean_t = np.mean(agg[tag])
|
|
||||||
grand += mean_t
|
|
||||||
rows.append((tag, mean_t))
|
|
||||||
|
|
||||||
for tag, mean_t in rows:
|
|
||||||
pct = mean_t / grand * 100 if grand > 0 else 0
|
|
||||||
print(f"{tag:<25} {mean_t:>10.1f} {pct:>7.1f}%")
|
|
||||||
print("-" * 70)
|
|
||||||
print(f"{'TOTAL (attn+ff)':<25} {grand:>10.1f}")
|
|
||||||
print()
|
|
||||||
|
|
||||||
|
|
||||||
def print_unet_detailed(all_runs_instances):
|
|
||||||
"""Print per-instance UNet sub-module detail (--detailed mode)."""
|
|
||||||
import numpy as np
|
|
||||||
# Use last run's data
|
|
||||||
by_instance = all_runs_instances[-1]
|
|
||||||
print("=" * 100)
|
|
||||||
print("DETAILED: PER-INSTANCE UNET SUB-MODULE TIMING (last run)")
|
|
||||||
print("=" * 100)
|
|
||||||
print(f"{'Type':<22} {'Module Name':<45} {'Calls':>6} {'Total(ms)':>10} {'Mean(ms)':>10}")
|
|
||||||
print("-" * 100)
|
|
||||||
|
|
||||||
rows = []
|
|
||||||
for (tag, mod_name), times in by_instance.items():
|
|
||||||
if len(times) == 0:
|
|
||||||
continue
|
|
||||||
total = sum(times)
|
|
||||||
mean = np.mean(times)
|
|
||||||
rows.append((tag, mod_name, len(times), total, mean))
|
|
||||||
rows.sort(key=lambda r: r[3], reverse=True)
|
|
||||||
|
|
||||||
for tag, mod_name, count, total, mean in rows:
|
|
||||||
short_name = mod_name[-42:] if len(mod_name) > 42 else mod_name
|
|
||||||
print(f"{tag:<22} {short_name:<45} {count:>6} {total:>10.2f} {mean:>10.3f}")
|
|
||||||
print()
|
|
||||||
|
|
||||||
|
|
||||||
def print_memory_summary(mem_before, mem_peak):
|
|
||||||
"""Table 3: Memory Summary."""
|
|
||||||
delta = mem_peak - mem_before
|
|
||||||
print("=" * 50)
|
|
||||||
print("TABLE 3: MEMORY SUMMARY")
|
|
||||||
print("=" * 50)
|
|
||||||
print(f" Initial allocated: {mem_before / 1e9:.2f} GB")
|
|
||||||
print(f" Peak allocated: {mem_peak / 1e9:.2f} GB")
|
|
||||||
print(f" Delta (pipeline): {delta / 1e9:.2f} GB")
|
|
||||||
print()
|
|
||||||
|
|
||||||
|
|
||||||
def print_throughput(all_runs_stages, all_bw, ddim_steps, cfg_scale):
|
|
||||||
"""Table 4: Throughput — total latency, per-step, per-UNet-forward, VAE bandwidth."""
|
|
||||||
import numpy as np
|
|
||||||
n_runs = len(all_runs_stages)
|
|
||||||
|
|
||||||
# Total latency
|
|
||||||
totals = []
|
|
||||||
for run in all_runs_stages:
|
|
||||||
totals.append(sum(run.values()))
|
|
||||||
mean_total = np.mean(totals)
|
|
||||||
|
|
||||||
# DDIM loop time
|
|
||||||
ddim_times = [run["5_DDIM_Loop"] for run in all_runs_stages]
|
|
||||||
mean_ddim = np.mean(ddim_times)
|
|
||||||
|
|
||||||
unet_calls = ddim_steps if cfg_scale == 1.0 else ddim_steps * 2
|
|
||||||
per_step = mean_ddim / ddim_steps
|
|
||||||
per_unet = mean_ddim / unet_calls
|
|
||||||
|
|
||||||
# VAE bandwidth
|
|
||||||
mean_enc_time = np.mean([run["2_VAE_Encode"] for run in all_runs_stages])
|
|
||||||
mean_dec_time = np.mean([run["6_VAE_Decode"] for run in all_runs_stages])
|
|
||||||
|
|
||||||
bw = all_bw[-1] # use last run's byte counts
|
|
||||||
enc_bytes = bw["vae_enc_input_bytes"] + bw["vae_enc_output_bytes"]
|
|
||||||
dec_bytes = bw["vae_dec_input_bytes"] + bw["vae_dec_output_bytes"]
|
|
||||||
enc_bw = enc_bytes / (mean_enc_time / 1000) / 1e9 if mean_enc_time > 0 else 0
|
|
||||||
dec_bw = dec_bytes / (mean_dec_time / 1000) / 1e9 if mean_dec_time > 0 else 0
|
|
||||||
|
|
||||||
print("=" * 60)
|
|
||||||
print("TABLE 4: THROUGHPUT")
|
|
||||||
print("=" * 60)
|
|
||||||
print(f" Total pipeline latency: {mean_total:.1f} ms")
|
|
||||||
print(f" DDIM loop latency: {mean_ddim:.1f} ms")
|
|
||||||
print(f" DDIM steps: {ddim_steps}")
|
|
||||||
print(f" CFG scale: {cfg_scale} ({'2x UNet/step' if cfg_scale != 1.0 else '1x UNet/step'})")
|
|
||||||
print(f" UNet forward calls: {unet_calls}")
|
|
||||||
print(f" Per DDIM step: {per_step:.1f} ms")
|
|
||||||
print(f" Per UNet forward: {per_unet:.1f} ms")
|
|
||||||
print(f" VAE encode bandwidth: {enc_bw:.1f} GB/s (peak HBM: {MEM_BW_GBS} GB/s)")
|
|
||||||
print(f" VAE decode bandwidth: {dec_bw:.1f} GB/s (peak HBM: {MEM_BW_GBS} GB/s)")
|
|
||||||
print(f" GPU BF16 peak: {PEAK_BF16_TFLOPS} TFLOPS")
|
|
||||||
print()
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Main
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
def main():
|
|
||||||
patch_norm_bypass_autocast()
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Profile the full inference pipeline")
|
|
||||||
parser.add_argument("--ckpt_path", type=str, required=True)
|
|
||||||
parser.add_argument("--config", type=str, required=True)
|
|
||||||
parser.add_argument("--ddim_steps", type=int, default=50)
|
|
||||||
parser.add_argument("--cfg_scale", type=float, default=1.0)
|
|
||||||
parser.add_argument("--n_runs", type=int, default=3)
|
|
||||||
parser.add_argument("--warmup", type=int, default=1)
|
|
||||||
parser.add_argument("--detailed", action="store_true",
|
|
||||||
help="Print per-instance UNet sub-module detail")
|
|
||||||
parser.add_argument("--deep", action="store_true",
|
|
||||||
help="Enable deep DDIM analysis: per-block, attn vs ff")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
noise_shape = [1, 4, 16, 40, 64]
|
|
||||||
|
|
||||||
# --- Load model ---
|
|
||||||
print("Loading model...")
|
|
||||||
model = load_model(args)
|
|
||||||
observation, prompts = build_dummy_inputs(model, noise_shape)
|
|
||||||
|
|
||||||
# --- Setup hook profiler ---
|
|
||||||
unet = model.model.diffusion_model
|
|
||||||
hook_profiler = HookProfiler(unet, deep=args.deep)
|
|
||||||
hook_profiler.register()
|
|
||||||
print(f"Registered hooks on {len(hook_profiler.handles)} sub-modules")
|
|
||||||
|
|
||||||
# --- Warmup ---
|
|
||||||
print(f"Warmup: {args.warmup} run(s)...")
|
|
||||||
with torch.no_grad():
|
|
||||||
for i in range(args.warmup):
|
|
||||||
run_pipeline(model, observation, prompts, noise_shape,
|
|
||||||
args.ddim_steps, args.cfg_scale, hook_profiler)
|
|
||||||
print(f" warmup {i+1}/{args.warmup} done")
|
|
||||||
|
|
||||||
# --- Measurement runs ---
|
|
||||||
print(f"Measuring: {args.n_runs} run(s)...")
|
|
||||||
torch.cuda.reset_peak_memory_stats()
|
|
||||||
mem_before = torch.cuda.memory_allocated()
|
|
||||||
|
|
||||||
all_stages = []
|
|
||||||
all_hooks = []
|
|
||||||
all_instances = []
|
|
||||||
all_blocks = []
|
|
||||||
all_bw = []
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for i in range(args.n_runs):
|
|
||||||
stage_times, hook_by_type, hook_by_instance, hook_by_block, bw = run_pipeline(
|
|
||||||
model, observation, prompts, noise_shape,
|
|
||||||
args.ddim_steps, args.cfg_scale, hook_profiler)
|
|
||||||
all_stages.append(stage_times)
|
|
||||||
all_hooks.append(hook_by_type)
|
|
||||||
all_instances.append(hook_by_instance)
|
|
||||||
all_blocks.append(hook_by_block)
|
|
||||||
all_bw.append(bw)
|
|
||||||
total = sum(stage_times.values())
|
|
||||||
print(f" run {i+1}/{args.n_runs}: {total:.1f} ms total")
|
|
||||||
|
|
||||||
mem_peak = torch.cuda.max_memory_allocated()
|
|
||||||
|
|
||||||
# --- Reports ---
|
|
||||||
print_stage_timing(all_stages)
|
|
||||||
print_unet_breakdown(all_hooks)
|
|
||||||
print_block_timing(all_blocks)
|
|
||||||
if args.deep:
|
|
||||||
print_attn_ff_breakdown(all_hooks)
|
|
||||||
if args.detailed:
|
|
||||||
print_unet_detailed(all_instances)
|
|
||||||
print_memory_summary(mem_before, mem_peak)
|
|
||||||
print_throughput(all_stages, all_bw, args.ddim_steps, args.cfg_scale)
|
|
||||||
|
|
||||||
hook_profiler.remove()
|
|
||||||
print("Done.")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,287 +0,0 @@
|
|||||||
"""
|
|
||||||
Profile one DDIM sampling iteration to capture all matmul/attention ops,
|
|
||||||
their matrix sizes, wall time, and compute utilization.
|
|
||||||
|
|
||||||
Uses torch.profiler for CUDA timing and FlopCounterMode for accurate
|
|
||||||
FLOPS counting (works on ROCm where Tensile kernels don't report FLOPS).
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 \
|
|
||||||
python scripts/evaluation/profile_unet.py \
|
|
||||||
--ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
|
|
||||||
--config configs/inference/world_model_interaction.yaml
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from collections import OrderedDict, defaultdict
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
from torch.utils.flop_counter import FlopCounterMode
|
|
||||||
from unifolm_wma.utils.utils import instantiate_from_config
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
def patch_norm_bypass_autocast():
|
|
||||||
"""Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy."""
|
|
||||||
|
|
||||||
def _group_norm_forward(self, x):
|
|
||||||
with torch.amp.autocast('cuda', enabled=False):
|
|
||||||
return F.group_norm(
|
|
||||||
x, self.num_groups,
|
|
||||||
self.weight.to(x.dtype) if self.weight is not None else None,
|
|
||||||
self.bias.to(x.dtype) if self.bias is not None else None,
|
|
||||||
self.eps)
|
|
||||||
|
|
||||||
def _layer_norm_forward(self, x):
|
|
||||||
with torch.amp.autocast('cuda', enabled=False):
|
|
||||||
return F.layer_norm(
|
|
||||||
x, self.normalized_shape,
|
|
||||||
self.weight.to(x.dtype) if self.weight is not None else None,
|
|
||||||
self.bias.to(x.dtype) if self.bias is not None else None,
|
|
||||||
self.eps)
|
|
||||||
|
|
||||||
torch.nn.GroupNorm.forward = _group_norm_forward
|
|
||||||
torch.nn.LayerNorm.forward = _layer_norm_forward
|
|
||||||
|
|
||||||
|
|
||||||
# --- W7900D theoretical peak (TFLOPS) ---
|
|
||||||
PEAK_BF16_TFLOPS = 61.0
|
|
||||||
PEAK_FP32_TFLOPS = 30.5
|
|
||||||
|
|
||||||
|
|
||||||
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
|
|
||||||
"""Compile ResBlock._forward in the hottest output_blocks for operator fusion."""
|
|
||||||
from unifolm_wma.modules.networks.wma_model import ResBlock
|
|
||||||
unet = model.model.diffusion_model
|
|
||||||
compiled = 0
|
|
||||||
for idx in hot_indices:
|
|
||||||
block = unet.output_blocks[idx]
|
|
||||||
for layer in block:
|
|
||||||
if isinstance(layer, ResBlock):
|
|
||||||
layer._forward = torch.compile(layer._forward, mode="default")
|
|
||||||
compiled += 1
|
|
||||||
print(f" ✓ torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(args):
|
|
||||||
config = OmegaConf.load(args.config)
|
|
||||||
config['model']['params']['wma_config']['params']['use_checkpoint'] = False
|
|
||||||
model = instantiate_from_config(config.model)
|
|
||||||
|
|
||||||
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
|
||||||
if "state_dict" in state_dict:
|
|
||||||
state_dict = state_dict["state_dict"]
|
|
||||||
model.load_state_dict(state_dict, strict=True)
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
model.model.to(torch.bfloat16)
|
|
||||||
apply_torch_compile(model)
|
|
||||||
model = model.cuda()
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def build_call_kwargs(model, noise_shape):
|
|
||||||
"""Build dummy inputs matching the hybrid conditioning forward signature."""
|
|
||||||
device = next(model.parameters()).device
|
|
||||||
B, C, T, H, W = noise_shape # [1, 4, 16, 40, 64]
|
|
||||||
dtype = torch.bfloat16
|
|
||||||
|
|
||||||
x_action = torch.randn(B, 16, 16, device=device, dtype=dtype)
|
|
||||||
x_state = torch.randn(B, 16, 16, device=device, dtype=dtype)
|
|
||||||
timesteps = torch.tensor([500], device=device, dtype=torch.long)
|
|
||||||
context = torch.randn(B, 351, 1024, device=device, dtype=dtype)
|
|
||||||
|
|
||||||
obs_images = torch.randn(B, 3, 2, 320, 512, device=device, dtype=dtype)
|
|
||||||
obs_state = torch.randn(B, 2, 16, device=device, dtype=dtype)
|
|
||||||
context_action = [obs_images, obs_state, True, False]
|
|
||||||
fps = torch.tensor([1], device=device, dtype=torch.long)
|
|
||||||
|
|
||||||
x_raw = torch.randn(B, C, T, H, W, device=device, dtype=dtype)
|
|
||||||
c_concat = [torch.randn(B, C, T, H, W, device=device, dtype=dtype)]
|
|
||||||
|
|
||||||
return dict(
|
|
||||||
x=x_raw, x_action=x_action, x_state=x_state, t=timesteps,
|
|
||||||
c_concat=c_concat, c_crossattn=[context],
|
|
||||||
c_crossattn_action=context_action, s=fps,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def profile_one_step(model, noise_shape):
|
|
||||||
"""Run one UNet forward pass under torch.profiler for CUDA timing."""
|
|
||||||
diff_wrapper = model.model
|
|
||||||
call_kwargs = build_call_kwargs(model, noise_shape)
|
|
||||||
|
|
||||||
with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
|
|
||||||
# Warmup
|
|
||||||
for _ in range(2):
|
|
||||||
_ = diff_wrapper(**call_kwargs)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
with torch.profiler.profile(
|
|
||||||
activities=[torch.profiler.ProfilerActivity.CUDA],
|
|
||||||
record_shapes=True,
|
|
||||||
with_flops=True,
|
|
||||||
) as prof:
|
|
||||||
_ = diff_wrapper(**call_kwargs)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
return prof
|
|
||||||
|
|
||||||
|
|
||||||
def count_flops(model, noise_shape):
|
|
||||||
"""Run one UNet forward pass under FlopCounterMode for accurate FLOPS."""
|
|
||||||
diff_wrapper = model.model
|
|
||||||
call_kwargs = build_call_kwargs(model, noise_shape)
|
|
||||||
|
|
||||||
with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
|
|
||||||
flop_counter = FlopCounterMode(display=False)
|
|
||||||
with flop_counter:
|
|
||||||
_ = diff_wrapper(**call_kwargs)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
return flop_counter
|
|
||||||
|
|
||||||
|
|
||||||
def print_report(prof, flop_counter):
|
|
||||||
"""Parse profiler results and print a structured report with accurate FLOPS."""
|
|
||||||
events = prof.key_averages()
|
|
||||||
|
|
||||||
# --- Extract per-operator FLOPS from FlopCounterMode ---
|
|
||||||
# flop_counts is {module_name: {op_name: count}}; use only "Global" to avoid double-counting
|
|
||||||
flop_by_op = {}
|
|
||||||
flop_by_module = {}
|
|
||||||
if hasattr(flop_counter, 'flop_counts'):
|
|
||||||
# Per-op: only from top-level "Global" entry (no parent/child duplication)
|
|
||||||
global_ops = flop_counter.flop_counts.get("Global", {})
|
|
||||||
for op_name, flop_count in global_ops.items():
|
|
||||||
key = str(op_name).split('.')[-1]
|
|
||||||
flop_by_op[key] = flop_by_op.get(key, 0) + flop_count
|
|
||||||
|
|
||||||
# Per-module: collect all, skip "Global" and top-level wrapper duplicates
|
|
||||||
for module_name, op_dict in flop_counter.flop_counts.items():
|
|
||||||
module_total = sum(op_dict.values())
|
|
||||||
if module_total > 0:
|
|
||||||
flop_by_module[module_name] = module_total
|
|
||||||
|
|
||||||
total_counted_flops = flop_counter.get_total_flops()
|
|
||||||
|
|
||||||
# Collect matmul-like ops
|
|
||||||
matmul_ops = []
|
|
||||||
other_ops = []
|
|
||||||
|
|
||||||
for evt in events:
|
|
||||||
if evt.device_time_total <= 0:
|
|
||||||
continue
|
|
||||||
name = evt.key
|
|
||||||
is_matmul = any(k in name.lower() for k in
|
|
||||||
['mm', 'gemm', 'addmm', 'bmm', 'einsum', 'dot', 'linear'])
|
|
||||||
entry = {
|
|
||||||
'name': name,
|
|
||||||
'input_shapes': str(evt.input_shapes) if evt.input_shapes else '',
|
|
||||||
'cuda_time_ms': evt.device_time_total / 1000.0,
|
|
||||||
'count': evt.count,
|
|
||||||
'flops': evt.flops if evt.flops else 0,
|
|
||||||
}
|
|
||||||
if is_matmul:
|
|
||||||
matmul_ops.append(entry)
|
|
||||||
else:
|
|
||||||
other_ops.append(entry)
|
|
||||||
|
|
||||||
# Sort by CUDA time
|
|
||||||
matmul_ops.sort(key=lambda x: x['cuda_time_ms'], reverse=True)
|
|
||||||
other_ops.sort(key=lambda x: x['cuda_time_ms'], reverse=True)
|
|
||||||
|
|
||||||
total_cuda_ms = sum(e['cuda_time_ms'] for e in matmul_ops + other_ops)
|
|
||||||
total_matmul_ms = sum(e['cuda_time_ms'] for e in matmul_ops)
|
|
||||||
# --- Print matmul ops ---
|
|
||||||
print("=" * 130)
|
|
||||||
print("MATMUL / LINEAR OPS (sorted by CUDA time)")
|
|
||||||
print("=" * 130)
|
|
||||||
print(f"{'Op':>35} | {'Count':>5} | {'CUDA(ms)':>10} | Shapes")
|
|
||||||
print("-" * 130)
|
|
||||||
|
|
||||||
for op in matmul_ops:
|
|
||||||
shapes_str = op['input_shapes'][:60] if op['input_shapes'] else ''
|
|
||||||
print(f"{op['name']:>35} | {op['count']:>5} | {op['cuda_time_ms']:>10.3f} | {shapes_str}")
|
|
||||||
|
|
||||||
# --- Print top non-matmul ops ---
|
|
||||||
print()
|
|
||||||
print("=" * 130)
|
|
||||||
print("TOP NON-MATMUL OPS (sorted by CUDA time)")
|
|
||||||
print("=" * 130)
|
|
||||||
print(f"{'Op':>40} | {'Count':>5} | {'CUDA(ms)':>10} | Shapes")
|
|
||||||
print("-" * 130)
|
|
||||||
|
|
||||||
for op in other_ops[:20]:
|
|
||||||
shapes_str = op['input_shapes'][:60] if op['input_shapes'] else ''
|
|
||||||
print(f"{op['name']:>40} | {op['count']:>5} | {op['cuda_time_ms']:>10.3f} | {shapes_str}")
|
|
||||||
|
|
||||||
# --- FlopCounterMode per-operator breakdown ---
|
|
||||||
print()
|
|
||||||
print("=" * 130)
|
|
||||||
print("FLOPS BY ATen OPERATOR (FlopCounterMode)")
|
|
||||||
print("=" * 130)
|
|
||||||
print(f"{'ATen Op':>25} | {'GFLOPS':>12} | {'% of Total':>10}")
|
|
||||||
print("-" * 55)
|
|
||||||
|
|
||||||
sorted_flop_ops = sorted(flop_by_op.items(), key=lambda x: x[1], reverse=True)
|
|
||||||
for op_name, flops in sorted_flop_ops:
|
|
||||||
gflops = flops / 1e9
|
|
||||||
pct = flops / total_counted_flops * 100 if total_counted_flops > 0 else 0
|
|
||||||
print(f"{op_name:>25} | {gflops:>12.2f} | {pct:>9.1f}%")
|
|
||||||
|
|
||||||
# --- FlopCounterMode per-module breakdown ---
|
|
||||||
if flop_by_module:
|
|
||||||
print()
|
|
||||||
print("=" * 130)
|
|
||||||
print("FLOPS BY MODULE (FlopCounterMode)")
|
|
||||||
print("=" * 130)
|
|
||||||
print(f"{'Module':>60} | {'GFLOPS':>12} | {'% of Total':>10}")
|
|
||||||
print("-" * 90)
|
|
||||||
|
|
||||||
sorted_modules = sorted(flop_by_module.items(), key=lambda x: x[1], reverse=True)
|
|
||||||
for mod_name, flops in sorted_modules[:30]:
|
|
||||||
gflops = flops / 1e9
|
|
||||||
pct = flops / total_counted_flops * 100 if total_counted_flops > 0 else 0
|
|
||||||
name_str = mod_name[-57:] if len(mod_name) > 57 else mod_name
|
|
||||||
print(f"{name_str:>60} | {gflops:>12.2f} | {pct:>9.1f}%")
|
|
||||||
|
|
||||||
# --- Summary ---
|
|
||||||
print()
|
|
||||||
print("=" * 130)
|
|
||||||
print("SUMMARY")
|
|
||||||
print("=" * 130)
|
|
||||||
print(f" Total CUDA time: {total_cuda_ms:.1f} ms")
|
|
||||||
print(f" Matmul CUDA time: {total_matmul_ms:.1f} ms ({total_matmul_ms/total_cuda_ms*100:.1f}%)")
|
|
||||||
print(f" Non-matmul CUDA time: {total_cuda_ms - total_matmul_ms:.1f} ms ({(total_cuda_ms-total_matmul_ms)/total_cuda_ms*100:.1f}%)")
|
|
||||||
print(f" Total FLOPS (FlopCounter): {total_counted_flops/1e9:.2f} GFLOPS")
|
|
||||||
if total_matmul_ms > 0 and total_counted_flops > 0:
|
|
||||||
avg_tflops = total_counted_flops / (total_matmul_ms / 1000.0) / 1e12
|
|
||||||
avg_util = avg_tflops / PEAK_BF16_TFLOPS * 100
|
|
||||||
overall_tflops = total_counted_flops / (total_cuda_ms / 1000.0) / 1e12
|
|
||||||
overall_util = overall_tflops / PEAK_BF16_TFLOPS * 100
|
|
||||||
print(f" Matmul throughput: {avg_tflops:.2f} TFLOPS/s ({avg_util:.1f}% of BF16 peak)")
|
|
||||||
print(f" Overall throughput: {overall_tflops:.2f} TFLOPS/s ({overall_util:.1f}% of BF16 peak)")
|
|
||||||
print(f" GPU peak (BF16): {PEAK_BF16_TFLOPS} TFLOPS")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
patch_norm_bypass_autocast()
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--ckpt_path", type=str, required=True)
|
|
||||||
parser.add_argument("--config", type=str, required=True)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
print("Loading model...")
|
|
||||||
model = load_model(args)
|
|
||||||
|
|
||||||
noise_shape = [1, 4, 16, 40, 64]
|
|
||||||
|
|
||||||
print(f"Profiling UNet forward pass with shape {noise_shape}...")
|
|
||||||
prof = profile_one_step(model, noise_shape)
|
|
||||||
|
|
||||||
print("Counting FLOPS with FlopCounterMode...")
|
|
||||||
flop_counter = count_flops(model, noise_shape)
|
|
||||||
|
|
||||||
print_report(prof, flop_counter)
|
|
||||||
@@ -19,6 +19,9 @@ from fastapi.responses import JSONResponse
|
|||||||
from typing import Any, Dict, Optional, Tuple, List
|
from typing import Any, Dict, Optional, Tuple, List
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
from unifolm_wma.utils.utils import instantiate_from_config
|
from unifolm_wma.utils.utils import instantiate_from_config
|
||||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import argparse, os, glob
|
import argparse, os, glob
|
||||||
from contextlib import nullcontext
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import random
|
import random
|
||||||
import torch
|
import torch
|
||||||
@@ -10,6 +9,9 @@ import logging
|
|||||||
import einops
|
import einops
|
||||||
import warnings
|
import warnings
|
||||||
import imageio
|
import imageio
|
||||||
|
import atexit
|
||||||
|
import multiprocessing as mp
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
from pytorch_lightning import seed_everything
|
from pytorch_lightning import seed_everything
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
@@ -17,39 +19,18 @@ from tqdm import tqdm
|
|||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from eval_utils import populate_queues, log_to_tensorboard
|
from eval_utils import populate_queues
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from typing import Optional, List, Any
|
||||||
|
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||||
from unifolm_wma.utils.utils import instantiate_from_config
|
from unifolm_wma.utils.utils import instantiate_from_config
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
def patch_norm_bypass_autocast():
|
|
||||||
"""Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy.
|
|
||||||
This eliminates bf16->fp32->bf16 dtype conversions during UNet forward."""
|
|
||||||
|
|
||||||
def _group_norm_forward(self, x):
|
|
||||||
with torch.amp.autocast('cuda', enabled=False):
|
|
||||||
return F.group_norm(
|
|
||||||
x, self.num_groups,
|
|
||||||
self.weight.to(x.dtype) if self.weight is not None else None,
|
|
||||||
self.bias.to(x.dtype) if self.bias is not None else None,
|
|
||||||
self.eps)
|
|
||||||
|
|
||||||
def _layer_norm_forward(self, x):
|
|
||||||
with torch.amp.autocast('cuda', enabled=False):
|
|
||||||
return F.layer_norm(
|
|
||||||
x, self.normalized_shape,
|
|
||||||
self.weight.to(x.dtype) if self.weight is not None else None,
|
|
||||||
self.bias.to(x.dtype) if self.bias is not None else None,
|
|
||||||
self.eps)
|
|
||||||
|
|
||||||
torch.nn.GroupNorm.forward = _group_norm_forward
|
|
||||||
torch.nn.LayerNorm.forward = _layer_norm_forward
|
|
||||||
|
|
||||||
|
|
||||||
def get_device_from_parameters(module: nn.Module) -> torch.device:
|
def get_device_from_parameters(module: nn.Module) -> torch.device:
|
||||||
@@ -64,92 +45,6 @@ def get_device_from_parameters(module: nn.Module) -> torch.device:
|
|||||||
return next(iter(module.parameters())).device
|
return next(iter(module.parameters())).device
|
||||||
|
|
||||||
|
|
||||||
def apply_precision_settings(model: nn.Module, args: argparse.Namespace) -> nn.Module:
|
|
||||||
"""Apply precision settings to model components based on command-line arguments.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model (nn.Module): The model to apply precision settings to.
|
|
||||||
args (argparse.Namespace): Parsed command-line arguments containing precision settings.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
nn.Module: Model with precision settings applied.
|
|
||||||
"""
|
|
||||||
print(f">>> Applying precision settings:")
|
|
||||||
print(f" - Diffusion dtype: {args.diffusion_dtype}")
|
|
||||||
print(f" - Projector mode: {args.projector_mode}")
|
|
||||||
print(f" - Encoder mode: {args.encoder_mode}")
|
|
||||||
print(f" - VAE dtype: {args.vae_dtype}")
|
|
||||||
|
|
||||||
# 1. Set Diffusion backbone precision
|
|
||||||
if args.diffusion_dtype == "bf16":
|
|
||||||
# Convert diffusion model weights to bf16
|
|
||||||
model.model.to(torch.bfloat16)
|
|
||||||
model.diffusion_autocast_dtype = torch.bfloat16
|
|
||||||
print(" ✓ Diffusion model weights converted to bfloat16")
|
|
||||||
else:
|
|
||||||
model.diffusion_autocast_dtype = torch.bfloat16
|
|
||||||
print(" ✓ Diffusion model using fp32")
|
|
||||||
|
|
||||||
# 2. Set Projector precision
|
|
||||||
if args.projector_mode == "bf16_full":
|
|
||||||
model.state_projector.to(torch.bfloat16)
|
|
||||||
model.action_projector.to(torch.bfloat16)
|
|
||||||
model.projector_autocast_dtype = None
|
|
||||||
print(" ✓ Projectors converted to bfloat16")
|
|
||||||
elif args.projector_mode == "autocast":
|
|
||||||
model.projector_autocast_dtype = torch.bfloat16
|
|
||||||
print(" ✓ Projectors will use autocast (weights fp32, compute bf16)")
|
|
||||||
else:
|
|
||||||
model.projector_autocast_dtype = None
|
|
||||||
# fp32 mode: do nothing, keep original precision
|
|
||||||
|
|
||||||
# 3. Set Encoder precision
|
|
||||||
if args.encoder_mode == "bf16_full":
|
|
||||||
model.embedder.to(torch.bfloat16)
|
|
||||||
model.image_proj_model.to(torch.bfloat16)
|
|
||||||
model.encoder_autocast_dtype = None
|
|
||||||
print(" ✓ Encoders converted to bfloat16")
|
|
||||||
elif args.encoder_mode == "autocast":
|
|
||||||
model.encoder_autocast_dtype = torch.bfloat16
|
|
||||||
print(" ✓ Encoders will use autocast (weights fp32, compute bf16)")
|
|
||||||
else:
|
|
||||||
model.encoder_autocast_dtype = None
|
|
||||||
# fp32 mode: do nothing, keep original precision
|
|
||||||
|
|
||||||
# 4. Set VAE precision
|
|
||||||
if args.vae_dtype == "bf16":
|
|
||||||
model.first_stage_model.to(torch.bfloat16)
|
|
||||||
print(" ✓ VAE converted to bfloat16")
|
|
||||||
else:
|
|
||||||
print(" ✓ VAE kept in fp32 for best quality")
|
|
||||||
|
|
||||||
# 5. Safety net: ensure no fp32 parameters remain when all components are bf16
|
|
||||||
if args.diffusion_dtype == "bf16":
|
|
||||||
fp32_params = [(n, p) for n, p in model.named_parameters() if p.dtype == torch.float32]
|
|
||||||
if fp32_params:
|
|
||||||
print(f" ⚠ Found {len(fp32_params)} fp32 params, converting to bf16")
|
|
||||||
for name, param in fp32_params:
|
|
||||||
param.data = param.data.to(torch.bfloat16)
|
|
||||||
print(" ✓ All parameters converted to bfloat16")
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
|
|
||||||
"""Compile ResBlock._forward in the hottest output_blocks for operator fusion."""
|
|
||||||
from unifolm_wma.modules.networks.wma_model import ResBlock
|
|
||||||
unet = model.model.diffusion_model
|
|
||||||
compiled = 0
|
|
||||||
for idx in hot_indices:
|
|
||||||
block = unet.output_blocks[idx]
|
|
||||||
for layer in block:
|
|
||||||
if isinstance(layer, ResBlock):
|
|
||||||
layer._forward = torch.compile(layer._forward, mode="default")
|
|
||||||
compiled += 1
|
|
||||||
print(f" ✓ torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
|
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
|
||||||
"""Save a list of frames to a video file.
|
"""Save a list of frames to a video file.
|
||||||
|
|
||||||
@@ -262,6 +157,107 @@ def save_results(video: Tensor, filename: str, fps: int = 8) -> None:
|
|||||||
options={'crf': '10'})
|
options={'crf': '10'})
|
||||||
|
|
||||||
|
|
||||||
|
# ========== Async I/O ==========
|
||||||
|
_io_executor: Optional[ThreadPoolExecutor] = None
|
||||||
|
_io_futures: List[Any] = []
|
||||||
|
|
||||||
|
|
||||||
|
def _get_io_executor() -> ThreadPoolExecutor:
|
||||||
|
global _io_executor
|
||||||
|
if _io_executor is None:
|
||||||
|
_io_executor = ThreadPoolExecutor(max_workers=2)
|
||||||
|
return _io_executor
|
||||||
|
|
||||||
|
|
||||||
|
def _flush_io():
|
||||||
|
"""Wait for all pending async I/O to finish."""
|
||||||
|
global _io_futures
|
||||||
|
for fut in _io_futures:
|
||||||
|
try:
|
||||||
|
fut.result()
|
||||||
|
except Exception as e:
|
||||||
|
print(f">>> [async I/O] error: {e}")
|
||||||
|
_io_futures.clear()
|
||||||
|
|
||||||
|
|
||||||
|
atexit.register(_flush_io)
|
||||||
|
|
||||||
|
|
||||||
|
def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None:
|
||||||
|
"""Synchronous save on CPU tensor (runs in background thread)."""
|
||||||
|
video = torch.clamp(video_cpu.float(), -1., 1.)
|
||||||
|
n = video.shape[0]
|
||||||
|
video = video.permute(2, 0, 1, 3, 4)
|
||||||
|
frame_grids = [
|
||||||
|
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
|
||||||
|
for framesheet in video
|
||||||
|
]
|
||||||
|
grid = torch.stack(frame_grids, dim=0)
|
||||||
|
grid = (grid + 1.0) / 2.0
|
||||||
|
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
|
||||||
|
torchvision.io.write_video(filename,
|
||||||
|
grid,
|
||||||
|
fps=fps,
|
||||||
|
video_codec='h264',
|
||||||
|
options={'crf': '10'})
|
||||||
|
|
||||||
|
|
||||||
|
def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
|
||||||
|
"""Submit video saving to background thread pool."""
|
||||||
|
video_cpu = video.detach().cpu()
|
||||||
|
fut = _get_io_executor().submit(_save_results_sync, video_cpu, filename, fps)
|
||||||
|
_io_futures.append(fut)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_to_tb_sync(writer, video_cpu: Tensor, tag: str, fps: int) -> None:
|
||||||
|
"""Synchronous TensorBoard log on CPU tensor (runs in background thread)."""
|
||||||
|
if video_cpu.dim() == 5:
|
||||||
|
n = video_cpu.shape[0]
|
||||||
|
video = video_cpu.permute(2, 0, 1, 3, 4)
|
||||||
|
frame_grids = [
|
||||||
|
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
|
||||||
|
for framesheet in video
|
||||||
|
]
|
||||||
|
grid = torch.stack(frame_grids, dim=0)
|
||||||
|
grid = (grid + 1.0) / 2.0
|
||||||
|
grid = grid.unsqueeze(dim=0)
|
||||||
|
writer.add_video(tag, grid, fps=fps)
|
||||||
|
|
||||||
|
|
||||||
|
def log_to_tensorboard_async(writer, data: Tensor, tag: str, fps: int = 10) -> None:
|
||||||
|
"""Submit TensorBoard logging to background thread pool."""
|
||||||
|
if isinstance(data, torch.Tensor) and data.dim() == 5:
|
||||||
|
data_cpu = data.detach().cpu()
|
||||||
|
fut = _get_io_executor().submit(_log_to_tb_sync, writer, data_cpu, tag, fps)
|
||||||
|
_io_futures.append(fut)
|
||||||
|
|
||||||
|
|
||||||
|
def _video_tensor_to_frames(video: Tensor) -> np.ndarray:
|
||||||
|
video = torch.clamp(video.float(), -1., 1.)
|
||||||
|
n = video.shape[0]
|
||||||
|
video = video.permute(2, 0, 1, 3, 4)
|
||||||
|
frame_grids = [
|
||||||
|
torchvision.utils.make_grid(f, nrow=int(n), padding=0) for f in video
|
||||||
|
]
|
||||||
|
grid = torch.stack(frame_grids, dim=0)
|
||||||
|
grid = ((grid + 1.0) / 2.0 * 255).to(torch.uint8).permute(0, 2, 3, 1)
|
||||||
|
return grid.numpy()[:, :, :, ::-1]
|
||||||
|
|
||||||
|
|
||||||
|
def _video_writer_process(q: mp.Queue, filename: str, fps: int):
|
||||||
|
frames = []
|
||||||
|
while True:
|
||||||
|
item = q.get()
|
||||||
|
if item is None:
|
||||||
|
break
|
||||||
|
frames.append(_video_tensor_to_frames(item))
|
||||||
|
if frames:
|
||||||
|
grid = np.concatenate(frames, axis=0)
|
||||||
|
grid = torch.from_numpy(grid[:, :, :, ::-1].copy()) # BGR → RGB
|
||||||
|
torchvision.io.write_video(filename, grid, fps=fps,
|
||||||
|
video_codec='h264', options={'crf': '10'})
|
||||||
|
|
||||||
|
|
||||||
def get_init_frame_path(data_dir: str, sample: dict) -> str:
|
def get_init_frame_path(data_dir: str, sample: dict) -> str:
|
||||||
"""Construct the init_frame path from directory and sample metadata.
|
"""Construct the init_frame path from directory and sample metadata.
|
||||||
|
|
||||||
@@ -374,11 +370,6 @@ def get_latent_z(model, videos: Tensor) -> Tensor:
|
|||||||
"""
|
"""
|
||||||
b, c, t, h, w = videos.shape
|
b, c, t, h, w = videos.shape
|
||||||
x = rearrange(videos, 'b c t h w -> (b t) c h w')
|
x = rearrange(videos, 'b c t h w -> (b t) c h w')
|
||||||
|
|
||||||
# Auto-detect VAE dtype and convert input
|
|
||||||
vae_dtype = next(model.first_stage_model.parameters()).dtype
|
|
||||||
x = x.to(dtype=vae_dtype)
|
|
||||||
|
|
||||||
z = model.encode_first_stage(x)
|
z = model.encode_first_stage(x)
|
||||||
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
|
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
|
||||||
return z
|
return z
|
||||||
@@ -484,20 +475,9 @@ def image_guided_synthesis_sim_mode(
|
|||||||
|
|
||||||
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
|
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
|
||||||
|
|
||||||
# Auto-detect model dtype and convert inputs accordingly
|
|
||||||
model_dtype = next(model.embedder.parameters()).dtype
|
|
||||||
|
|
||||||
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
||||||
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=model_dtype)
|
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
|
||||||
|
with torch.cuda.amp.autocast(dtype=torch.float16):
|
||||||
# Encoder autocast: weights stay fp32, compute in bf16
|
|
||||||
enc_ac_dtype = getattr(model, 'encoder_autocast_dtype', None)
|
|
||||||
if enc_ac_dtype is not None and model.device.type == 'cuda':
|
|
||||||
enc_ctx = torch.autocast('cuda', dtype=enc_ac_dtype)
|
|
||||||
else:
|
|
||||||
enc_ctx = nullcontext()
|
|
||||||
|
|
||||||
with enc_ctx:
|
|
||||||
cond_img_emb = model.embedder(cond_img)
|
cond_img_emb = model.embedder(cond_img)
|
||||||
cond_img_emb = model.image_proj_model(cond_img_emb)
|
cond_img_emb = model.image_proj_model(cond_img_emb)
|
||||||
|
|
||||||
@@ -513,21 +493,11 @@ def image_guided_synthesis_sim_mode(
|
|||||||
prompts = [""] * batch_size
|
prompts = [""] * batch_size
|
||||||
cond_ins_emb = model.get_learned_conditioning(prompts)
|
cond_ins_emb = model.get_learned_conditioning(prompts)
|
||||||
|
|
||||||
# Auto-detect projector dtype and convert inputs
|
with torch.cuda.amp.autocast(dtype=torch.float16):
|
||||||
projector_dtype = next(model.state_projector.parameters()).dtype
|
cond_state_emb = model.state_projector(observation['observation.state'])
|
||||||
|
|
||||||
# Projector autocast: weights stay fp32, compute in bf16
|
|
||||||
proj_ac_dtype = getattr(model, 'projector_autocast_dtype', None)
|
|
||||||
if proj_ac_dtype is not None and model.device.type == 'cuda':
|
|
||||||
proj_ctx = torch.autocast('cuda', dtype=proj_ac_dtype)
|
|
||||||
else:
|
|
||||||
proj_ctx = nullcontext()
|
|
||||||
|
|
||||||
with proj_ctx:
|
|
||||||
cond_state_emb = model.state_projector(observation['observation.state'].to(dtype=projector_dtype))
|
|
||||||
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
|
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
|
||||||
|
|
||||||
cond_action_emb = model.action_projector(observation['action'].to(dtype=projector_dtype))
|
cond_action_emb = model.action_projector(observation['action'])
|
||||||
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
|
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
|
||||||
|
|
||||||
if not sim_mode:
|
if not sim_mode:
|
||||||
@@ -550,18 +520,10 @@ def image_guided_synthesis_sim_mode(
|
|||||||
kwargs.update({"unconditional_conditioning_img_nonetext": None})
|
kwargs.update({"unconditional_conditioning_img_nonetext": None})
|
||||||
cond_mask = None
|
cond_mask = None
|
||||||
cond_z0 = None
|
cond_z0 = None
|
||||||
|
|
||||||
# Setup autocast context for diffusion sampling
|
|
||||||
autocast_dtype = getattr(model, 'diffusion_autocast_dtype', None)
|
|
||||||
if autocast_dtype is not None and model.device.type == 'cuda':
|
|
||||||
autocast_ctx = torch.autocast('cuda', dtype=autocast_dtype)
|
|
||||||
else:
|
|
||||||
autocast_ctx = nullcontext()
|
|
||||||
|
|
||||||
batch_variants = None
|
batch_variants = None
|
||||||
|
samples = None
|
||||||
if ddim_sampler is not None:
|
if ddim_sampler is not None:
|
||||||
with autocast_ctx:
|
samples, actions, states, intermedia = ddim_sampler.sample(
|
||||||
samples, actions, states, intermedia = ddim_sampler.sample(
|
|
||||||
S=ddim_steps,
|
S=ddim_steps,
|
||||||
conditioning=cond,
|
conditioning=cond,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
@@ -583,7 +545,7 @@ def image_guided_synthesis_sim_mode(
|
|||||||
batch_images = model.decode_first_stage(samples)
|
batch_images = model.decode_first_stage(samples)
|
||||||
batch_variants = batch_images
|
batch_variants = batch_images
|
||||||
|
|
||||||
return batch_variants, actions, states
|
return batch_variants, actions, states, samples
|
||||||
|
|
||||||
|
|
||||||
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||||
@@ -608,40 +570,67 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
|
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
|
||||||
df = pd.read_csv(csv_path)
|
df = pd.read_csv(csv_path)
|
||||||
|
|
||||||
# Load config
|
# Load config (always needed for data setup)
|
||||||
config = OmegaConf.load(args.config)
|
config = OmegaConf.load(args.config)
|
||||||
config['model']['params']['wma_config']['params'][
|
|
||||||
'use_checkpoint'] = False
|
|
||||||
model = instantiate_from_config(config.model)
|
|
||||||
model.perframe_ae = args.perframe_ae
|
|
||||||
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
|
|
||||||
model = load_model_checkpoint(model, args.ckpt_path)
|
|
||||||
model.eval()
|
|
||||||
print(f'>>> Load pre-trained model ...')
|
|
||||||
|
|
||||||
# Apply precision settings before moving to GPU
|
prepared_path = args.ckpt_path + ".prepared.pt"
|
||||||
model = apply_precision_settings(model, args)
|
if os.path.exists(prepared_path):
|
||||||
|
# ---- Fast path: load the fully-prepared model ----
|
||||||
|
print(f">>> Loading prepared model from {prepared_path} ...")
|
||||||
|
model = torch.load(prepared_path,
|
||||||
|
map_location=f"cuda:{gpu_no}",
|
||||||
|
weights_only=False,
|
||||||
|
mmap=True)
|
||||||
|
model.eval()
|
||||||
|
print(f">>> Prepared model loaded.")
|
||||||
|
else:
|
||||||
|
# ---- Normal path: construct + load checkpoint ----
|
||||||
|
config['model']['params']['wma_config']['params'][
|
||||||
|
'use_checkpoint'] = False
|
||||||
|
model = instantiate_from_config(config.model)
|
||||||
|
model.perframe_ae = args.perframe_ae
|
||||||
|
|
||||||
# Compile hot ResBlocks for operator fusion
|
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
|
||||||
apply_torch_compile(model)
|
model = load_model_checkpoint(model, args.ckpt_path)
|
||||||
|
model.eval()
|
||||||
|
model = model.cuda(gpu_no)
|
||||||
|
print(f'>>> Load pre-trained model ...')
|
||||||
|
|
||||||
# Export precision-converted checkpoint if requested
|
# Save prepared model for fast loading next time
|
||||||
if args.export_precision_ckpt:
|
print(f">>> Saving prepared model to {prepared_path} ...")
|
||||||
export_path = args.export_precision_ckpt
|
torch.save(model, prepared_path)
|
||||||
os.makedirs(os.path.dirname(export_path) or '.', exist_ok=True)
|
print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
|
||||||
torch.save({"state_dict": model.state_dict()}, export_path)
|
|
||||||
print(f">>> Precision-converted checkpoint saved to: {export_path}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Build unnomalizer
|
# ---- FP16: convert diffusion backbone + conditioning modules ----
|
||||||
|
model.model.to(torch.float16)
|
||||||
|
model.model.diffusion_model.dtype = torch.float16
|
||||||
|
print(">>> Diffusion backbone (model.model) converted to FP16.")
|
||||||
|
|
||||||
|
# Projectors / MLP → FP16
|
||||||
|
model.image_proj_model.half()
|
||||||
|
model.state_projector.half()
|
||||||
|
model.action_projector.half()
|
||||||
|
print(">>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.")
|
||||||
|
|
||||||
|
# Text/image encoders → FP16
|
||||||
|
model.cond_stage_model.half()
|
||||||
|
model.embedder.half()
|
||||||
|
print(">>> Encoders (cond_stage_model, embedder) converted to FP16.")
|
||||||
|
|
||||||
|
# Build normalizer (always needed, independent of model loading path)
|
||||||
logging.info("***** Configing Data *****")
|
logging.info("***** Configing Data *****")
|
||||||
data = instantiate_from_config(config.data)
|
data = instantiate_from_config(config.data)
|
||||||
data.setup()
|
data.setup()
|
||||||
print(">>> Dataset is successfully loaded ...")
|
print(">>> Dataset is successfully loaded ...")
|
||||||
|
|
||||||
model = model.cuda(gpu_no)
|
|
||||||
device = get_device_from_parameters(model)
|
device = get_device_from_parameters(model)
|
||||||
|
|
||||||
|
# Fuse KV projections in attention layers (to_k + to_v → to_kv)
|
||||||
|
from unifolm_wma.modules.attention import CrossAttention
|
||||||
|
kv_count = sum(1 for m in model.modules()
|
||||||
|
if isinstance(m, CrossAttention) and m.fuse_kv())
|
||||||
|
print(f" ✓ KV fused: {kv_count} attention layers")
|
||||||
|
|
||||||
# Run over data
|
# Run over data
|
||||||
assert (args.height % 16 == 0) and (
|
assert (args.height % 16 == 0) and (
|
||||||
args.width % 16
|
args.width % 16
|
||||||
@@ -686,8 +675,13 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
# For saving environmental changes in world-model
|
# For saving environmental changes in world-model
|
||||||
sample_save_dir = f'{video_save_dir}/wm/{fs}'
|
sample_save_dir = f'{video_save_dir}/wm/{fs}'
|
||||||
os.makedirs(sample_save_dir, exist_ok=True)
|
os.makedirs(sample_save_dir, exist_ok=True)
|
||||||
# For collecting interaction videos
|
# Writer process for incremental video saving
|
||||||
wm_video = []
|
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
|
||||||
|
write_q = mp.Queue()
|
||||||
|
writer_proc = mp.Process(
|
||||||
|
target=_video_writer_process,
|
||||||
|
args=(write_q, sample_full_video_file, args.save_fps))
|
||||||
|
writer_proc.start()
|
||||||
# Initialize observation queues
|
# Initialize observation queues
|
||||||
cond_obs_queues = {
|
cond_obs_queues = {
|
||||||
"observation.images.top":
|
"observation.images.top":
|
||||||
@@ -743,7 +737,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
|
|
||||||
# Use world-model in policy to generate action
|
# Use world-model in policy to generate action
|
||||||
print(f'>>> Step {itr}: generating actions ...')
|
print(f'>>> Step {itr}: generating actions ...')
|
||||||
pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
|
pred_videos_0, pred_actions, _, _ = image_guided_synthesis_sim_mode(
|
||||||
model,
|
model,
|
||||||
sample['instruction'],
|
sample['instruction'],
|
||||||
observation,
|
observation,
|
||||||
@@ -785,7 +779,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
|
|
||||||
# Interaction with the world-model
|
# Interaction with the world-model
|
||||||
print(f'>>> Step {itr}: interacting with world model ...')
|
print(f'>>> Step {itr}: interacting with world model ...')
|
||||||
pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
|
pred_videos_1, _, pred_states, wm_samples = image_guided_synthesis_sim_mode(
|
||||||
model,
|
model,
|
||||||
"",
|
"",
|
||||||
observation,
|
observation,
|
||||||
@@ -798,12 +792,16 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
fs=model_input_fs,
|
fs=model_input_fs,
|
||||||
text_input=False,
|
text_input=False,
|
||||||
timestep_spacing=args.timestep_spacing,
|
timestep_spacing=args.timestep_spacing,
|
||||||
guidance_rescale=args.guidance_rescale)
|
guidance_rescale=args.guidance_rescale,
|
||||||
|
decode_video=False)
|
||||||
|
|
||||||
|
# Decode only the last frame for CLIP embedding in next iteration
|
||||||
|
last_frame_pixel = model.decode_first_stage(wm_samples[:, :, -1:, :, :])
|
||||||
|
|
||||||
for idx in range(args.exe_steps):
|
for idx in range(args.exe_steps):
|
||||||
observation = {
|
observation = {
|
||||||
'observation.images.top':
|
'observation.images.top':
|
||||||
pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
|
last_frame_pixel[0, :, 0:1].permute(1, 0, 2, 3),
|
||||||
'observation.state':
|
'observation.state':
|
||||||
torch.zeros_like(pred_states[0][idx:idx + 1]) if
|
torch.zeros_like(pred_states[0][idx:idx + 1]) if
|
||||||
args.zero_pred_state else pred_states[0][idx:idx + 1],
|
args.zero_pred_state else pred_states[0][idx:idx + 1],
|
||||||
@@ -814,44 +812,26 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
cond_obs_queues = populate_queues(cond_obs_queues,
|
cond_obs_queues = populate_queues(cond_obs_queues,
|
||||||
observation)
|
observation)
|
||||||
|
|
||||||
# Save the imagen videos for decision-making
|
# Save the imagen videos for decision-making (async)
|
||||||
if pred_videos_0 is not None:
|
if pred_videos_0 is not None:
|
||||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
|
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
|
||||||
log_to_tensorboard(writer,
|
log_to_tensorboard_async(writer,
|
||||||
pred_videos_0,
|
pred_videos_0,
|
||||||
sample_tag,
|
sample_tag,
|
||||||
fps=args.save_fps)
|
fps=args.save_fps)
|
||||||
# Save videos environment changes via world-model interaction
|
|
||||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
|
|
||||||
log_to_tensorboard(writer,
|
|
||||||
pred_videos_1,
|
|
||||||
sample_tag,
|
|
||||||
fps=args.save_fps)
|
|
||||||
|
|
||||||
# Save the imagen videos for decision-making
|
|
||||||
if pred_videos_0 is not None:
|
|
||||||
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
|
|
||||||
save_results(pred_videos_0.cpu(),
|
|
||||||
sample_video_file,
|
|
||||||
fps=args.save_fps)
|
|
||||||
# Save videos environment changes via world-model interaction
|
|
||||||
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
|
|
||||||
save_results(pred_videos_1.cpu(),
|
|
||||||
sample_video_file,
|
|
||||||
fps=args.save_fps)
|
|
||||||
|
|
||||||
print('>' * 24)
|
print('>' * 24)
|
||||||
# Collect the result of world-model interactions
|
# Decode segment and send to writer process
|
||||||
wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu())
|
seg_video = model.decode_first_stage(
|
||||||
|
wm_samples[:, :, :args.exe_steps]).detach().cpu()
|
||||||
|
write_q.put(seg_video)
|
||||||
|
|
||||||
full_video = torch.cat(wm_video, dim=2)
|
# Stop writer process
|
||||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
|
write_q.put(None)
|
||||||
log_to_tensorboard(writer,
|
writer_proc.join()
|
||||||
full_video,
|
|
||||||
sample_tag,
|
# Wait for all async I/O to complete
|
||||||
fps=args.save_fps)
|
_flush_io()
|
||||||
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
|
|
||||||
save_results(full_video, sample_full_video_file, fps=args.save_fps)
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@@ -969,46 +949,16 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--fast_policy_no_decode",
|
"--fast_policy_no_decode",
|
||||||
action='store_true',
|
action='store_true',
|
||||||
default=False,
|
default=True,
|
||||||
help="Speed mode: policy pass only predicts actions, skip policy video decode/log/save.")
|
help="Speed mode: policy pass only predicts actions, skip policy video decode/log/save.")
|
||||||
parser.add_argument("--save_fps",
|
parser.add_argument("--save_fps",
|
||||||
type=int,
|
type=int,
|
||||||
default=8,
|
default=8,
|
||||||
help="fps for the saving video")
|
help="fps for the saving video")
|
||||||
parser.add_argument(
|
|
||||||
"--diffusion_dtype",
|
|
||||||
type=str,
|
|
||||||
choices=["fp32", "bf16"],
|
|
||||||
default="bf16",
|
|
||||||
help="Diffusion backbone precision (fp32/bf16)")
|
|
||||||
parser.add_argument(
|
|
||||||
"--projector_mode",
|
|
||||||
type=str,
|
|
||||||
choices=["fp32", "autocast", "bf16_full"],
|
|
||||||
default="bf16_full",
|
|
||||||
help="Projector precision mode (fp32/autocast/bf16_full)")
|
|
||||||
parser.add_argument(
|
|
||||||
"--encoder_mode",
|
|
||||||
type=str,
|
|
||||||
choices=["fp32", "autocast", "bf16_full"],
|
|
||||||
default="bf16_full",
|
|
||||||
help="Encoder precision mode (fp32/autocast/bf16_full)")
|
|
||||||
parser.add_argument(
|
|
||||||
"--vae_dtype",
|
|
||||||
type=str,
|
|
||||||
choices=["fp32", "bf16"],
|
|
||||||
default="fp32",
|
|
||||||
help="VAE precision (fp32/bf16, most affects image quality)")
|
|
||||||
parser.add_argument(
|
|
||||||
"--export_precision_ckpt",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Export precision-converted checkpoint to this path, then exit.")
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
patch_norm_bypass_autocast()
|
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
seed = args.seed
|
seed = args.seed
|
||||||
|
|||||||
@@ -11,6 +11,9 @@ from unifolm_wma.utils.utils import instantiate_from_config
|
|||||||
from unifolm_wma.utils.train import get_trainer_callbacks, get_trainer_logger, get_trainer_strategy
|
from unifolm_wma.utils.train import get_trainer_callbacks, get_trainer_logger, get_trainer_strategy
|
||||||
from unifolm_wma.utils.train import set_logger, init_workspace, load_checkpoints, get_num_parameters
|
from unifolm_wma.utils.train import set_logger, init_workspace, load_checkpoints, get_num_parameters
|
||||||
|
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
|
|
||||||
def get_parser(**parser_kwargs):
|
def get_parser(**parser_kwargs):
|
||||||
parser = argparse.ArgumentParser(**parser_kwargs)
|
parser = argparse.ArgumentParser(**parser_kwargs)
|
||||||
|
|||||||
@@ -988,7 +988,7 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
def instantiate_cond_stage(self, config: OmegaConf) -> None:
|
def instantiate_cond_stage(self, config: OmegaConf) -> None:
|
||||||
"""
|
"""
|
||||||
Build the conditioning stage model.
|
Build the conditioning stage model. Frozen models are converted to FP16.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: OmegaConf config describing the conditioning model to instantiate.
|
config: OmegaConf config describing the conditioning model to instantiate.
|
||||||
@@ -1000,6 +1000,7 @@ class LatentDiffusion(DDPM):
|
|||||||
self.cond_stage_model.train = disabled_train
|
self.cond_stage_model.train = disabled_train
|
||||||
for param in self.cond_stage_model.parameters():
|
for param in self.cond_stage_model.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
self.cond_stage_model.half()
|
||||||
else:
|
else:
|
||||||
model = instantiate_from_config(config)
|
model = instantiate_from_config(config)
|
||||||
self.cond_stage_model = model
|
self.cond_stage_model = model
|
||||||
@@ -1014,17 +1015,18 @@ class LatentDiffusion(DDPM):
|
|||||||
Returns:
|
Returns:
|
||||||
Conditioning embedding as a tensor (shape depends on cond model).
|
Conditioning embedding as a tensor (shape depends on cond model).
|
||||||
"""
|
"""
|
||||||
if self.cond_stage_forward is None:
|
with torch.cuda.amp.autocast(dtype=torch.float16):
|
||||||
if hasattr(self.cond_stage_model, 'encode') and callable(
|
if self.cond_stage_forward is None:
|
||||||
self.cond_stage_model.encode):
|
if hasattr(self.cond_stage_model, 'encode') and callable(
|
||||||
c = self.cond_stage_model.encode(c)
|
self.cond_stage_model.encode):
|
||||||
if isinstance(c, DiagonalGaussianDistribution):
|
c = self.cond_stage_model.encode(c)
|
||||||
c = c.mode()
|
if isinstance(c, DiagonalGaussianDistribution):
|
||||||
|
c = c.mode()
|
||||||
|
else:
|
||||||
|
c = self.cond_stage_model(c)
|
||||||
else:
|
else:
|
||||||
c = self.cond_stage_model(c)
|
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
|
||||||
else:
|
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
|
||||||
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
|
|
||||||
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
|
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def get_first_stage_encoding(
|
def get_first_stage_encoding(
|
||||||
@@ -1105,10 +1107,6 @@ class LatentDiffusion(DDPM):
|
|||||||
else:
|
else:
|
||||||
reshape_back = False
|
reshape_back = False
|
||||||
|
|
||||||
# Align input dtype with VAE weights (e.g. fp32 samples → bf16 VAE)
|
|
||||||
vae_dtype = next(self.first_stage_model.parameters()).dtype
|
|
||||||
z = z.to(dtype=vae_dtype)
|
|
||||||
|
|
||||||
if not self.perframe_ae:
|
if not self.perframe_ae:
|
||||||
z = 1. / self.scale_factor * z
|
z = 1. / self.scale_factor * z
|
||||||
results = self.first_stage_model.decode(z, **kwargs)
|
results = self.first_stage_model.decode(z, **kwargs)
|
||||||
@@ -1803,9 +1801,7 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
if ddim:
|
if ddim:
|
||||||
if not hasattr(self, '_ddim_sampler') or self._ddim_sampler is None:
|
ddim_sampler = DDIMSampler(self)
|
||||||
self._ddim_sampler = DDIMSampler(self)
|
|
||||||
ddim_sampler = self._ddim_sampler
|
|
||||||
shape = (self.channels, self.temporal_length, *self.image_size)
|
shape = (self.channels, self.temporal_length, *self.image_size)
|
||||||
samples, actions, states, intermediates = ddim_sampler.sample(
|
samples, actions, states, intermediates = ddim_sampler.sample(
|
||||||
ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
|
ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
|
||||||
@@ -1963,6 +1959,7 @@ class LatentVisualDiffusion(LatentDiffusion):
|
|||||||
self.image_proj_model.train = disabled_train
|
self.image_proj_model.train = disabled_train
|
||||||
for param in self.image_proj_model.parameters():
|
for param in self.image_proj_model.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
self.image_proj_model.half()
|
||||||
|
|
||||||
def _init_embedder(self, config: OmegaConf, freeze: bool = True) -> None:
|
def _init_embedder(self, config: OmegaConf, freeze: bool = True) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -1978,6 +1975,7 @@ class LatentVisualDiffusion(LatentDiffusion):
|
|||||||
self.embedder.train = disabled_train
|
self.embedder.train = disabled_train
|
||||||
for param in self.embedder.parameters():
|
for param in self.embedder.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
self.embedder.half()
|
||||||
|
|
||||||
def init_normalizers(self, normalize_config: OmegaConf,
|
def init_normalizers(self, normalize_config: OmegaConf,
|
||||||
dataset_stats: Mapping[str, Any]) -> None:
|
dataset_stats: Mapping[str, Any]) -> None:
|
||||||
@@ -2181,8 +2179,9 @@ class LatentVisualDiffusion(LatentDiffusion):
|
|||||||
(random_num < 3 * self.uncond_prob).float(), "n -> n 1 1 1")
|
(random_num < 3 * self.uncond_prob).float(), "n -> n 1 1 1")
|
||||||
|
|
||||||
cond_img = input_mask * img
|
cond_img = input_mask * img
|
||||||
cond_img_emb = self.embedder(cond_img)
|
with torch.cuda.amp.autocast(dtype=torch.float16):
|
||||||
cond_img_emb = self.image_proj_model(cond_img_emb)
|
cond_img_emb = self.embedder(cond_img)
|
||||||
|
cond_img_emb = self.image_proj_model(cond_img_emb)
|
||||||
|
|
||||||
if self.model.conditioning_key == 'hybrid':
|
if self.model.conditioning_key == 'hybrid':
|
||||||
if self.interp_mode:
|
if self.interp_mode:
|
||||||
@@ -2197,11 +2196,12 @@ class LatentVisualDiffusion(LatentDiffusion):
|
|||||||
repeat=z.shape[2])
|
repeat=z.shape[2])
|
||||||
cond["c_concat"] = [img_cat_cond]
|
cond["c_concat"] = [img_cat_cond]
|
||||||
|
|
||||||
cond_action = self.action_projector(action)
|
with torch.cuda.amp.autocast(dtype=torch.float16):
|
||||||
cond_action_emb = self.agent_action_pos_emb + cond_action
|
cond_action = self.action_projector(action)
|
||||||
# Get conditioning states
|
cond_action_emb = self.agent_action_pos_emb + cond_action
|
||||||
cond_state = self.state_projector(obs_state)
|
# Get conditioning states
|
||||||
cond_state_emb = self.agent_state_pos_emb + cond_state
|
cond_state = self.state_projector(obs_state)
|
||||||
|
cond_state_emb = self.agent_state_pos_emb + cond_state
|
||||||
|
|
||||||
if self.decision_making_only:
|
if self.decision_making_only:
|
||||||
is_sim_mode = False
|
is_sim_mode = False
|
||||||
@@ -2463,6 +2463,17 @@ class DiffusionWrapper(pl.LightningModule):
|
|||||||
Returns:
|
Returns:
|
||||||
Output from the inner diffusion model (tensor or tuple, depending on the model).
|
Output from the inner diffusion model (tensor or tuple, depending on the model).
|
||||||
"""
|
"""
|
||||||
|
with torch.cuda.amp.autocast(dtype=torch.float16):
|
||||||
|
return self._forward_impl(x, x_action, x_state, t,
|
||||||
|
c_concat, c_crossattn, c_crossattn_action,
|
||||||
|
c_adm, s, mask, **kwargs)
|
||||||
|
|
||||||
|
def _forward_impl(
|
||||||
|
self,
|
||||||
|
x, x_action, x_state, t,
|
||||||
|
c_concat=None, c_crossattn=None, c_crossattn_action=None,
|
||||||
|
c_adm=None, s=None, mask=None, **kwargs,
|
||||||
|
):
|
||||||
if self.conditioning_key is None:
|
if self.conditioning_key is None:
|
||||||
out = self.diffusion_model(x, t)
|
out = self.diffusion_model(x, t)
|
||||||
elif self.conditioning_key == 'concat':
|
elif self.conditioning_key == 'concat':
|
||||||
|
|||||||
@@ -8,14 +8,12 @@ class SinusoidalPosEmb(nn.Module):
|
|||||||
def __init__(self, dim):
|
def __init__(self, dim):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
# Dummy buffer so .to(dtype) propagates to this module
|
|
||||||
self.register_buffer('_dtype_buf', torch.zeros(1), persistent=False)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
device = x.device
|
device = x.device
|
||||||
half_dim = self.dim // 2
|
half_dim = self.dim // 2
|
||||||
emb = math.log(10000) / (half_dim - 1)
|
emb = math.log(10000) / (half_dim - 1)
|
||||||
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
||||||
emb = x.float()[:, None] * emb[None, :]
|
emb = x[:, None] * emb[None, :]
|
||||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||||
return emb.to(self._dtype_buf.dtype)
|
return emb
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ class DDIMSampler(object):
|
|||||||
self.ddpm_num_timesteps = model.num_timesteps
|
self.ddpm_num_timesteps = model.num_timesteps
|
||||||
self.schedule = schedule
|
self.schedule = schedule
|
||||||
self.counter = 0
|
self.counter = 0
|
||||||
self._schedule_key = None # (ddim_num_steps, ddim_discretize, ddim_eta)
|
|
||||||
|
|
||||||
def register_buffer(self, name, attr):
|
def register_buffer(self, name, attr):
|
||||||
if type(attr) == torch.Tensor:
|
if type(attr) == torch.Tensor:
|
||||||
@@ -31,11 +30,6 @@ class DDIMSampler(object):
|
|||||||
ddim_discretize="uniform",
|
ddim_discretize="uniform",
|
||||||
ddim_eta=0.,
|
ddim_eta=0.,
|
||||||
verbose=True):
|
verbose=True):
|
||||||
key = (ddim_num_steps, ddim_discretize, ddim_eta)
|
|
||||||
if self._schedule_key == key:
|
|
||||||
return
|
|
||||||
self._schedule_key = key
|
|
||||||
|
|
||||||
self.ddim_timesteps = make_ddim_timesteps(
|
self.ddim_timesteps = make_ddim_timesteps(
|
||||||
ddim_discr_method=ddim_discretize,
|
ddim_discr_method=ddim_discretize,
|
||||||
num_ddim_timesteps=ddim_num_steps,
|
num_ddim_timesteps=ddim_num_steps,
|
||||||
@@ -44,7 +38,7 @@ class DDIMSampler(object):
|
|||||||
alphas_cumprod = self.model.alphas_cumprod
|
alphas_cumprod = self.model.alphas_cumprod
|
||||||
assert alphas_cumprod.shape[
|
assert alphas_cumprod.shape[
|
||||||
0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||||
to_torch = lambda x: x.clone().detach().to(torch.float64).to(self.model
|
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model
|
||||||
.device)
|
.device)
|
||||||
|
|
||||||
if self.model.use_dynamic_rescale:
|
if self.model.use_dynamic_rescale:
|
||||||
@@ -217,9 +211,9 @@ class DDIMSampler(object):
|
|||||||
|
|
||||||
if precision is not None:
|
if precision is not None:
|
||||||
if precision == 16:
|
if precision == 16:
|
||||||
img = img.to(dtype=torch.bfloat16)
|
img = img.to(dtype=torch.float16)
|
||||||
action = action.to(dtype=torch.bfloat16)
|
action = action.to(dtype=torch.float16)
|
||||||
state = state.to(dtype=torch.bfloat16)
|
state = state.to(dtype=torch.float16)
|
||||||
|
|
||||||
if timesteps is None:
|
if timesteps is None:
|
||||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||||
@@ -390,10 +384,10 @@ class DDIMSampler(object):
|
|||||||
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||||
|
|
||||||
# Use 0-d tensors directly (already on device); broadcasting handles shape
|
# Use 0-d tensors directly (already on device); broadcasting handles shape
|
||||||
a_t = alphas[index].to(x.dtype)
|
a_t = alphas[index]
|
||||||
a_prev = alphas_prev[index].to(x.dtype)
|
a_prev = alphas_prev[index]
|
||||||
sigma_t = sigmas[index].to(x.dtype)
|
sigma_t = sigmas[index]
|
||||||
sqrt_one_minus_at = sqrt_one_minus_alphas[index].to(x.dtype)
|
sqrt_one_minus_at = sqrt_one_minus_alphas[index]
|
||||||
|
|
||||||
if self.model.parameterization != "v":
|
if self.model.parameterization != "v":
|
||||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
|
|||||||
@@ -86,8 +86,9 @@ class CrossAttention(nn.Module):
|
|||||||
self.relative_position_v = RelativePosition(
|
self.relative_position_v = RelativePosition(
|
||||||
num_units=dim_head, max_relative_position=temporal_length)
|
num_units=dim_head, max_relative_position=temporal_length)
|
||||||
else:
|
else:
|
||||||
## bmm fused-scale attention for all non-relative-position cases
|
## only used for spatial attention, while NOT for temporal attention
|
||||||
self.forward = self.bmm_forward
|
if XFORMERS_IS_AVAILBLE and temporal_length is None:
|
||||||
|
self.forward = self.efficient_forward
|
||||||
|
|
||||||
self.video_length = video_length
|
self.video_length = video_length
|
||||||
self.image_cross_attention = image_cross_attention
|
self.image_cross_attention = image_cross_attention
|
||||||
@@ -99,6 +100,7 @@ class CrossAttention(nn.Module):
|
|||||||
self.agent_action_context_len = agent_action_context_len
|
self.agent_action_context_len = agent_action_context_len
|
||||||
self._kv_cache = {}
|
self._kv_cache = {}
|
||||||
self._kv_cache_enabled = False
|
self._kv_cache_enabled = False
|
||||||
|
self._kv_fused = False
|
||||||
|
|
||||||
self.cross_attention_scale_learnable = cross_attention_scale_learnable
|
self.cross_attention_scale_learnable = cross_attention_scale_learnable
|
||||||
if self.image_cross_attention:
|
if self.image_cross_attention:
|
||||||
@@ -116,6 +118,27 @@ class CrossAttention(nn.Module):
|
|||||||
self.register_parameter('alpha_caa',
|
self.register_parameter('alpha_caa',
|
||||||
nn.Parameter(torch.tensor(0.)))
|
nn.Parameter(torch.tensor(0.)))
|
||||||
|
|
||||||
|
def fuse_kv(self):
|
||||||
|
"""Fuse to_k/to_v into to_kv (2 Linear → 1). Works for all layers."""
|
||||||
|
k_w = self.to_k.weight # (inner_dim, context_dim)
|
||||||
|
v_w = self.to_v.weight
|
||||||
|
self.to_kv = nn.Linear(k_w.shape[1], k_w.shape[0] * 2, bias=False)
|
||||||
|
self.to_kv.weight = nn.Parameter(torch.cat([k_w, v_w], dim=0))
|
||||||
|
del self.to_k, self.to_v
|
||||||
|
if self.image_cross_attention:
|
||||||
|
for suffix in ('_ip', '_as', '_aa'):
|
||||||
|
k_attr = f'to_k{suffix}'
|
||||||
|
v_attr = f'to_v{suffix}'
|
||||||
|
kw = getattr(self, k_attr).weight
|
||||||
|
vw = getattr(self, v_attr).weight
|
||||||
|
fused = nn.Linear(kw.shape[1], kw.shape[0] * 2, bias=False)
|
||||||
|
fused.weight = nn.Parameter(torch.cat([kw, vw], dim=0))
|
||||||
|
setattr(self, f'to_kv{suffix}', fused)
|
||||||
|
delattr(self, k_attr)
|
||||||
|
delattr(self, v_attr)
|
||||||
|
self._kv_fused = True
|
||||||
|
return True
|
||||||
|
|
||||||
def forward(self, x, context=None, mask=None):
|
def forward(self, x, context=None, mask=None):
|
||||||
spatial_self_attn = (context is None)
|
spatial_self_attn = (context is None)
|
||||||
k_ip, v_ip, out_ip = None, None, None
|
k_ip, v_ip, out_ip = None, None, None
|
||||||
@@ -127,7 +150,7 @@ class CrossAttention(nn.Module):
|
|||||||
context = default(context, x)
|
context = default(context, x)
|
||||||
|
|
||||||
if self.image_cross_attention and not spatial_self_attn:
|
if self.image_cross_attention and not spatial_self_attn:
|
||||||
# assert 1 > 2, ">>> ERROR: should setup xformers and use efficient_forward ..."
|
assert 1 > 2, ">>> ERROR: should setup xformers and use efficient_forward ..."
|
||||||
context_agent_state = context[:, :self.agent_state_context_len, :]
|
context_agent_state = context[:, :self.agent_state_context_len, :]
|
||||||
context_agent_action = context[:,
|
context_agent_action = context[:,
|
||||||
self.agent_state_context_len:self.
|
self.agent_state_context_len:self.
|
||||||
@@ -142,19 +165,28 @@ class CrossAttention(nn.Module):
|
|||||||
self.agent_action_context_len +
|
self.agent_action_context_len +
|
||||||
self.text_context_len:, :]
|
self.text_context_len:, :]
|
||||||
|
|
||||||
k = self.to_k(context_ins)
|
if self._kv_fused:
|
||||||
v = self.to_v(context_ins)
|
k, v = self.to_kv(context_ins).chunk(2, dim=-1)
|
||||||
k_ip = self.to_k_ip(context_image)
|
k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1)
|
||||||
v_ip = self.to_v_ip(context_image)
|
k_as, v_as = self.to_kv_as(context_agent_state).chunk(2, dim=-1)
|
||||||
k_as = self.to_k_as(context_agent_state)
|
k_aa, v_aa = self.to_kv_aa(context_agent_action).chunk(2, dim=-1)
|
||||||
v_as = self.to_v_as(context_agent_state)
|
else:
|
||||||
k_aa = self.to_k_aa(context_agent_action)
|
k = self.to_k(context_ins)
|
||||||
v_aa = self.to_v_aa(context_agent_action)
|
v = self.to_v(context_ins)
|
||||||
|
k_ip = self.to_k_ip(context_image)
|
||||||
|
v_ip = self.to_v_ip(context_image)
|
||||||
|
k_as = self.to_k_as(context_agent_state)
|
||||||
|
v_as = self.to_v_as(context_agent_state)
|
||||||
|
k_aa = self.to_k_aa(context_agent_action)
|
||||||
|
v_aa = self.to_v_aa(context_agent_action)
|
||||||
else:
|
else:
|
||||||
if not spatial_self_attn:
|
if not spatial_self_attn:
|
||||||
context = context[:, :self.text_context_len, :]
|
context = context[:, :self.text_context_len, :]
|
||||||
k = self.to_k(context)
|
if self._kv_fused:
|
||||||
v = self.to_v(context)
|
k, v = self.to_kv(context).chunk(2, dim=-1)
|
||||||
|
else:
|
||||||
|
k = self.to_k(context)
|
||||||
|
v = self.to_v(context)
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||||
(q, k, v))
|
(q, k, v))
|
||||||
@@ -175,8 +207,7 @@ class CrossAttention(nn.Module):
|
|||||||
sim.masked_fill_(~(mask > 0.5), max_neg_value)
|
sim.masked_fill_(~(mask > 0.5), max_neg_value)
|
||||||
|
|
||||||
# attention, what we cannot get enough of
|
# attention, what we cannot get enough of
|
||||||
with torch.amp.autocast('cuda', enabled=False):
|
sim = sim.softmax(dim=-1)
|
||||||
sim = sim.softmax(dim=-1)
|
|
||||||
|
|
||||||
out = torch.einsum('b i j, b j d -> b i d', sim, v)
|
out = torch.einsum('b i j, b j d -> b i d', sim, v)
|
||||||
if self.relative_position:
|
if self.relative_position:
|
||||||
@@ -193,8 +224,7 @@ class CrossAttention(nn.Module):
|
|||||||
sim_ip = torch.einsum('b i d, b j d -> b i j', q,
|
sim_ip = torch.einsum('b i d, b j d -> b i j', q,
|
||||||
k_ip) * self.scale
|
k_ip) * self.scale
|
||||||
del k_ip
|
del k_ip
|
||||||
with torch.amp.autocast('cuda', enabled=False):
|
sim_ip = sim_ip.softmax(dim=-1)
|
||||||
sim_ip = sim_ip.softmax(dim=-1)
|
|
||||||
out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip)
|
out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip)
|
||||||
out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
|
out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
|
||||||
@@ -205,8 +235,7 @@ class CrossAttention(nn.Module):
|
|||||||
sim_as = torch.einsum('b i d, b j d -> b i j', q,
|
sim_as = torch.einsum('b i d, b j d -> b i j', q,
|
||||||
k_as) * self.scale
|
k_as) * self.scale
|
||||||
del k_as
|
del k_as
|
||||||
with torch.amp.autocast('cuda', enabled=False):
|
sim_as = sim_as.softmax(dim=-1)
|
||||||
sim_as = sim_as.softmax(dim=-1)
|
|
||||||
out_as = torch.einsum('b i j, b j d -> b i d', sim_as, v_as)
|
out_as = torch.einsum('b i j, b j d -> b i d', sim_as, v_as)
|
||||||
out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)
|
out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
|
||||||
@@ -217,8 +246,7 @@ class CrossAttention(nn.Module):
|
|||||||
sim_aa = torch.einsum('b i d, b j d -> b i j', q,
|
sim_aa = torch.einsum('b i d, b j d -> b i j', q,
|
||||||
k_aa) * self.scale
|
k_aa) * self.scale
|
||||||
del k_aa
|
del k_aa
|
||||||
with torch.amp.autocast('cuda', enabled=False):
|
sim_aa = sim_aa.softmax(dim=-1)
|
||||||
sim_aa = sim_aa.softmax(dim=-1)
|
|
||||||
out_aa = torch.einsum('b i j, b j d -> b i d', sim_aa, v_aa)
|
out_aa = torch.einsum('b i j, b j d -> b i d', sim_aa, v_aa)
|
||||||
out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h)
|
out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
|
||||||
@@ -236,276 +264,168 @@ class CrossAttention(nn.Module):
|
|||||||
|
|
||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
|
|
||||||
def bmm_forward(self, x, context=None, mask=None):
|
|
||||||
spatial_self_attn = (context is None)
|
|
||||||
k_ip, v_ip, out_ip = None, None, None
|
|
||||||
k_as, v_as, out_as = None, None, None
|
|
||||||
k_aa, v_aa, out_aa = None, None, None
|
|
||||||
|
|
||||||
h = self.heads
|
|
||||||
q = self.to_q(x)
|
|
||||||
context = default(context, x)
|
|
||||||
|
|
||||||
use_cache = self._kv_cache_enabled and not spatial_self_attn
|
|
||||||
cache_hit = use_cache and len(self._kv_cache) > 0
|
|
||||||
|
|
||||||
if cache_hit:
|
|
||||||
# Reuse cached K/V (already in (b*h, n, d) shape)
|
|
||||||
k = self._kv_cache['k']
|
|
||||||
v = self._kv_cache['v']
|
|
||||||
if 'k_ip' in self._kv_cache:
|
|
||||||
k_ip = self._kv_cache['k_ip']
|
|
||||||
v_ip = self._kv_cache['v_ip']
|
|
||||||
k_as = self._kv_cache['k_as']
|
|
||||||
v_as = self._kv_cache['v_as']
|
|
||||||
k_aa = self._kv_cache['k_aa']
|
|
||||||
v_aa = self._kv_cache['v_aa']
|
|
||||||
q = rearrange(q, 'b n (h d) -> (b h) n d', h=h)
|
|
||||||
elif self.image_cross_attention and not spatial_self_attn:
|
|
||||||
context_agent_state = context[:, :self.agent_state_context_len, :]
|
|
||||||
context_agent_action = context[:,
|
|
||||||
self.agent_state_context_len:self.
|
|
||||||
agent_state_context_len +
|
|
||||||
self.agent_action_context_len, :]
|
|
||||||
context_ins = context[:, self.agent_state_context_len +
|
|
||||||
self.agent_action_context_len:self.
|
|
||||||
agent_state_context_len +
|
|
||||||
self.agent_action_context_len +
|
|
||||||
self.text_context_len, :]
|
|
||||||
context_image = context[:, self.agent_state_context_len +
|
|
||||||
self.agent_action_context_len +
|
|
||||||
self.text_context_len:, :]
|
|
||||||
|
|
||||||
k = self.to_k(context_ins)
|
|
||||||
v = self.to_v(context_ins)
|
|
||||||
k_ip = self.to_k_ip(context_image)
|
|
||||||
v_ip = self.to_v_ip(context_image)
|
|
||||||
k_as = self.to_k_as(context_agent_state)
|
|
||||||
v_as = self.to_v_as(context_agent_state)
|
|
||||||
k_aa = self.to_k_aa(context_agent_action)
|
|
||||||
v_aa = self.to_v_aa(context_agent_action)
|
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
|
||||||
(q, k, v))
|
|
||||||
k_ip, v_ip = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
|
||||||
(k_ip, v_ip))
|
|
||||||
k_as, v_as = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
|
||||||
(k_as, v_as))
|
|
||||||
k_aa, v_aa = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
|
||||||
(k_aa, v_aa))
|
|
||||||
|
|
||||||
if use_cache:
|
|
||||||
self._kv_cache = {
|
|
||||||
'k': k, 'v': v,
|
|
||||||
'k_ip': k_ip, 'v_ip': v_ip,
|
|
||||||
'k_as': k_as, 'v_as': v_as,
|
|
||||||
'k_aa': k_aa, 'v_aa': v_aa,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
if not spatial_self_attn:
|
|
||||||
context = context[:, :self.text_context_len, :]
|
|
||||||
k = self.to_k(context)
|
|
||||||
v = self.to_v(context)
|
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
|
||||||
(q, k, v))
|
|
||||||
|
|
||||||
if use_cache:
|
|
||||||
self._kv_cache = {'k': k, 'v': v}
|
|
||||||
|
|
||||||
# baddbmm: fuse scale into GEMM → one kernel instead of matmul + mul
|
|
||||||
sim = torch.baddbmm(
|
|
||||||
torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
|
|
||||||
q, k.transpose(-1, -2), beta=0, alpha=self.scale)
|
|
||||||
|
|
||||||
if exists(mask):
|
|
||||||
max_neg_value = -torch.finfo(sim.dtype).max
|
|
||||||
mask = repeat(mask, 'b i j -> (b h) i j', h=h)
|
|
||||||
sim.masked_fill_(~(mask > 0.5), max_neg_value)
|
|
||||||
|
|
||||||
with torch.amp.autocast('cuda', enabled=False):
|
|
||||||
sim = sim.softmax(dim=-1)
|
|
||||||
|
|
||||||
out = torch.bmm(sim, v)
|
|
||||||
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
|
||||||
|
|
||||||
if k_ip is not None and k_as is not None and k_aa is not None:
|
|
||||||
## image cross-attention (k_ip/v_ip already in (b*h, n, d) shape)
|
|
||||||
sim_ip = torch.baddbmm(
|
|
||||||
torch.empty(q.shape[0], q.shape[1], k_ip.shape[1], dtype=q.dtype, device=q.device),
|
|
||||||
q, k_ip.transpose(-1, -2), beta=0, alpha=self.scale)
|
|
||||||
with torch.amp.autocast('cuda', enabled=False):
|
|
||||||
sim_ip = sim_ip.softmax(dim=-1)
|
|
||||||
out_ip = torch.bmm(sim_ip, v_ip)
|
|
||||||
out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
|
|
||||||
|
|
||||||
## agent state cross-attention (k_as/v_as already in (b*h, n, d) shape)
|
|
||||||
sim_as = torch.baddbmm(
|
|
||||||
torch.empty(q.shape[0], q.shape[1], k_as.shape[1], dtype=q.dtype, device=q.device),
|
|
||||||
q, k_as.transpose(-1, -2), beta=0, alpha=self.scale)
|
|
||||||
with torch.amp.autocast('cuda', enabled=False):
|
|
||||||
sim_as = sim_as.softmax(dim=-1)
|
|
||||||
out_as = torch.bmm(sim_as, v_as)
|
|
||||||
out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)
|
|
||||||
|
|
||||||
## agent action cross-attention (k_aa/v_aa already in (b*h, n, d) shape)
|
|
||||||
sim_aa = torch.baddbmm(
|
|
||||||
torch.empty(q.shape[0], q.shape[1], k_aa.shape[1], dtype=q.dtype, device=q.device),
|
|
||||||
q, k_aa.transpose(-1, -2), beta=0, alpha=self.scale)
|
|
||||||
with torch.amp.autocast('cuda', enabled=False):
|
|
||||||
sim_aa = sim_aa.softmax(dim=-1)
|
|
||||||
out_aa = torch.bmm(sim_aa, v_aa)
|
|
||||||
out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h)
|
|
||||||
|
|
||||||
if out_ip is not None and out_as is not None and out_aa is not None:
|
|
||||||
if self.cross_attention_scale_learnable:
|
|
||||||
out = out + \
|
|
||||||
self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha_ctx) + 1) + \
|
|
||||||
self.agent_state_cross_attention_scale * out_as * (torch.tanh(self.alpha_cas) + 1) + \
|
|
||||||
self.agent_action_cross_attention_scale * out_aa * (torch.tanh(self.alpha_caa) + 1)
|
|
||||||
else:
|
|
||||||
out = out + \
|
|
||||||
self.image_cross_attention_scale * out_ip + \
|
|
||||||
self.agent_state_cross_attention_scale * out_as + \
|
|
||||||
self.agent_action_cross_attention_scale * out_aa
|
|
||||||
|
|
||||||
return self.to_out(out)
|
|
||||||
|
|
||||||
def efficient_forward(self, x, context=None, mask=None):
|
def efficient_forward(self, x, context=None, mask=None):
|
||||||
spatial_self_attn = (context is None)
|
spatial_self_attn = (context is None)
|
||||||
k, v, out = None, None, None
|
k, v, out = None, None, None
|
||||||
k_ip, v_ip, out_ip = None, None, None
|
k_ip, v_ip, out_ip = None, None, None
|
||||||
k_as, v_as, out_as = None, None, None
|
k_as, v_as, out_as = None, None, None
|
||||||
k_aa, v_aa, out_aa = None, None, None
|
k_aa, v_aa, out_aa = None, None, None
|
||||||
|
attn_mask_aa = None
|
||||||
|
|
||||||
|
h = self.heads
|
||||||
q = self.to_q(x)
|
q = self.to_q(x)
|
||||||
context = default(context, x)
|
context = default(context, x)
|
||||||
|
|
||||||
if self.image_cross_attention and not spatial_self_attn:
|
b, _, _ = q.shape
|
||||||
|
q = q.unsqueeze(3).reshape(b, q.shape[1], h, self.dim_head).permute(0, 2, 1, 3).reshape(b * h, q.shape[1], self.dim_head).contiguous()
|
||||||
|
|
||||||
|
def _reshape_kv(t):
|
||||||
|
return t.unsqueeze(3).reshape(b, t.shape[1], h, self.dim_head).permute(0, 2, 1, 3).reshape(b * h, t.shape[1], self.dim_head).contiguous()
|
||||||
|
|
||||||
|
use_cache = self._kv_cache_enabled and not spatial_self_attn
|
||||||
|
cache_hit = use_cache and len(self._kv_cache) > 0
|
||||||
|
|
||||||
|
if cache_hit:
|
||||||
|
k = self._kv_cache['k']
|
||||||
|
v = self._kv_cache['v']
|
||||||
|
k_ip = self._kv_cache.get('k_ip')
|
||||||
|
v_ip = self._kv_cache.get('v_ip')
|
||||||
|
k_as = self._kv_cache.get('k_as')
|
||||||
|
v_as = self._kv_cache.get('v_as')
|
||||||
|
k_aa = self._kv_cache.get('k_aa')
|
||||||
|
v_aa = self._kv_cache.get('v_aa')
|
||||||
|
attn_mask_aa = self._kv_cache.get('attn_mask_aa')
|
||||||
|
elif self.image_cross_attention and not spatial_self_attn:
|
||||||
if context.shape[1] == self.text_context_len + self.video_length:
|
if context.shape[1] == self.text_context_len + self.video_length:
|
||||||
context_ins, context_image = context[:, :self.text_context_len, :], context[:,self.text_context_len:, :]
|
context_ins, context_image = context[:, :self.text_context_len, :], context[:,self.text_context_len:, :]
|
||||||
k = self.to_k(context)
|
if self._kv_fused:
|
||||||
v = self.to_v(context)
|
k, v = self.to_kv(context).chunk(2, dim=-1)
|
||||||
k_ip = self.to_k_ip(context_image)
|
k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1)
|
||||||
v_ip = self.to_v_ip(context_image)
|
else:
|
||||||
|
k = self.to_k(context)
|
||||||
|
v = self.to_v(context)
|
||||||
|
k_ip = self.to_k_ip(context_image)
|
||||||
|
v_ip = self.to_v_ip(context_image)
|
||||||
|
k, v = map(_reshape_kv, (k, v))
|
||||||
|
k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip))
|
||||||
|
if use_cache:
|
||||||
|
self._kv_cache = {'k': k, 'v': v, 'k_ip': k_ip, 'v_ip': v_ip}
|
||||||
elif context.shape[1] == self.agent_state_context_len + self.text_context_len + self.video_length:
|
elif context.shape[1] == self.agent_state_context_len + self.text_context_len + self.video_length:
|
||||||
context_agent_state = context[:, :self.agent_state_context_len, :]
|
context_agent_state = context[:, :self.agent_state_context_len, :]
|
||||||
context_ins = context[:, self.agent_state_context_len:self.agent_state_context_len+self.text_context_len, :]
|
context_ins = context[:, self.agent_state_context_len:self.agent_state_context_len+self.text_context_len, :]
|
||||||
context_image = context[:, self.agent_state_context_len+self.text_context_len:, :]
|
context_image = context[:, self.agent_state_context_len+self.text_context_len:, :]
|
||||||
k = self.to_k(context_ins)
|
if self._kv_fused:
|
||||||
v = self.to_v(context_ins)
|
k, v = self.to_kv(context_ins).chunk(2, dim=-1)
|
||||||
k_ip = self.to_k_ip(context_image)
|
k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1)
|
||||||
v_ip = self.to_v_ip(context_image)
|
k_as, v_as = self.to_kv_as(context_agent_state).chunk(2, dim=-1)
|
||||||
k_as = self.to_k_as(context_agent_state)
|
else:
|
||||||
v_as = self.to_v_as(context_agent_state)
|
k = self.to_k(context_ins)
|
||||||
|
v = self.to_v(context_ins)
|
||||||
|
k_ip = self.to_k_ip(context_image)
|
||||||
|
v_ip = self.to_v_ip(context_image)
|
||||||
|
k_as = self.to_k_as(context_agent_state)
|
||||||
|
v_as = self.to_v_as(context_agent_state)
|
||||||
|
k, v = map(_reshape_kv, (k, v))
|
||||||
|
k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip))
|
||||||
|
k_as, v_as = map(_reshape_kv, (k_as, v_as))
|
||||||
|
if use_cache:
|
||||||
|
self._kv_cache = {'k': k, 'v': v, 'k_ip': k_ip, 'v_ip': v_ip, 'k_as': k_as, 'v_as': v_as}
|
||||||
else:
|
else:
|
||||||
context_agent_state = context[:, :self.agent_state_context_len, :]
|
context_agent_state = context[:, :self.agent_state_context_len, :]
|
||||||
context_agent_action = context[:, self.agent_state_context_len:self.agent_state_context_len+self.agent_action_context_len, :]
|
context_agent_action = context[:, self.agent_state_context_len:self.agent_state_context_len+self.agent_action_context_len, :]
|
||||||
context_ins = context[:, self.agent_state_context_len+self.agent_action_context_len:self.agent_state_context_len+self.agent_action_context_len+self.text_context_len, :]
|
context_ins = context[:, self.agent_state_context_len+self.agent_action_context_len:self.agent_state_context_len+self.agent_action_context_len+self.text_context_len, :]
|
||||||
context_image = context[:, self.agent_state_context_len+self.agent_action_context_len+self.text_context_len:, :]
|
context_image = context[:, self.agent_state_context_len+self.agent_action_context_len+self.text_context_len:, :]
|
||||||
|
|
||||||
k = self.to_k(context_ins)
|
if self._kv_fused:
|
||||||
v = self.to_v(context_ins)
|
k, v = self.to_kv(context_ins).chunk(2, dim=-1)
|
||||||
k_ip = self.to_k_ip(context_image)
|
k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1)
|
||||||
v_ip = self.to_v_ip(context_image)
|
k_as, v_as = self.to_kv_as(context_agent_state).chunk(2, dim=-1)
|
||||||
k_as = self.to_k_as(context_agent_state)
|
k_aa, v_aa = self.to_kv_aa(context_agent_action).chunk(2, dim=-1)
|
||||||
v_as = self.to_v_as(context_agent_state)
|
else:
|
||||||
k_aa = self.to_k_aa(context_agent_action)
|
k = self.to_k(context_ins)
|
||||||
v_aa = self.to_v_aa(context_agent_action)
|
v = self.to_v(context_ins)
|
||||||
|
k_ip = self.to_k_ip(context_image)
|
||||||
|
v_ip = self.to_v_ip(context_image)
|
||||||
|
k_as = self.to_k_as(context_agent_state)
|
||||||
|
v_as = self.to_v_as(context_agent_state)
|
||||||
|
k_aa = self.to_k_aa(context_agent_action)
|
||||||
|
v_aa = self.to_v_aa(context_agent_action)
|
||||||
|
|
||||||
attn_mask_aa = self._get_attn_mask_aa(x.shape[0],
|
k, v = map(_reshape_kv, (k, v))
|
||||||
q.shape[1],
|
k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip))
|
||||||
k_aa.shape[1],
|
k_as, v_as = map(_reshape_kv, (k_as, v_as))
|
||||||
block_size=16,
|
k_aa, v_aa = map(_reshape_kv, (k_aa, v_aa))
|
||||||
device=k_aa.device)
|
|
||||||
|
attn_mask_aa_raw = self._get_attn_mask_aa(x.shape[0],
|
||||||
|
q.shape[1],
|
||||||
|
k_aa.shape[1],
|
||||||
|
block_size=16,
|
||||||
|
device=k_aa.device)
|
||||||
|
attn_mask_aa = attn_mask_aa_raw.unsqueeze(1).repeat(1, h, 1, 1).reshape(
|
||||||
|
b * h, attn_mask_aa_raw.shape[1], attn_mask_aa_raw.shape[2]).to(q.dtype)
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
self._kv_cache = {
|
||||||
|
'k': k, 'v': v, 'k_ip': k_ip, 'v_ip': v_ip,
|
||||||
|
'k_as': k_as, 'v_as': v_as, 'k_aa': k_aa, 'v_aa': v_aa,
|
||||||
|
'attn_mask_aa': attn_mask_aa,
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
if not spatial_self_attn:
|
if not spatial_self_attn:
|
||||||
assert 1 > 2, ">>> ERROR: you should never go into here ..."
|
assert 1 > 2, ">>> ERROR: you should never go into here ..."
|
||||||
context = context[:, :self.text_context_len, :]
|
context = context[:, :self.text_context_len, :]
|
||||||
k = self.to_k(context)
|
if self._kv_fused:
|
||||||
v = self.to_v(context)
|
k, v = self.to_kv(context).chunk(2, dim=-1)
|
||||||
|
else:
|
||||||
b, _, _ = q.shape
|
k = self.to_k(context)
|
||||||
q = q.unsqueeze(3).reshape(b, q.shape[1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(b * self.heads, q.shape[1], self.dim_head).contiguous()
|
v = self.to_v(context)
|
||||||
|
k, v = map(_reshape_kv, (k, v))
|
||||||
|
if use_cache:
|
||||||
|
self._kv_cache = {'k': k, 'v': v}
|
||||||
if k is not None:
|
if k is not None:
|
||||||
k, v = map(
|
|
||||||
lambda t: t.unsqueeze(3).reshape(b, t.shape[
|
|
||||||
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
|
|
||||||
b * self.heads, t.shape[1], self.dim_head).contiguous(),
|
|
||||||
(k, v),
|
|
||||||
)
|
|
||||||
out = xformers.ops.memory_efficient_attention(q,
|
out = xformers.ops.memory_efficient_attention(q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
attn_bias=None,
|
attn_bias=None,
|
||||||
op=None)
|
op=None)
|
||||||
out = (out.unsqueeze(0).reshape(
|
out = (out.unsqueeze(0).reshape(
|
||||||
b, self.heads, out.shape[1],
|
b, h, out.shape[1],
|
||||||
self.dim_head).permute(0, 2, 1,
|
self.dim_head).permute(0, 2, 1,
|
||||||
3).reshape(b, out.shape[1],
|
3).reshape(b, out.shape[1],
|
||||||
self.heads * self.dim_head))
|
h * self.dim_head))
|
||||||
|
|
||||||
if k_ip is not None:
|
if k_ip is not None:
|
||||||
# For image cross-attention
|
|
||||||
k_ip, v_ip = map(
|
|
||||||
lambda t: t.unsqueeze(3).reshape(b, t.shape[
|
|
||||||
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
|
|
||||||
b * self.heads, t.shape[1], self.dim_head).contiguous(
|
|
||||||
),
|
|
||||||
(k_ip, v_ip),
|
|
||||||
)
|
|
||||||
out_ip = xformers.ops.memory_efficient_attention(q,
|
out_ip = xformers.ops.memory_efficient_attention(q,
|
||||||
k_ip,
|
k_ip,
|
||||||
v_ip,
|
v_ip,
|
||||||
attn_bias=None,
|
attn_bias=None,
|
||||||
op=None)
|
op=None)
|
||||||
out_ip = (out_ip.unsqueeze(0).reshape(
|
out_ip = (out_ip.unsqueeze(0).reshape(
|
||||||
b, self.heads, out_ip.shape[1],
|
b, h, out_ip.shape[1],
|
||||||
self.dim_head).permute(0, 2, 1,
|
self.dim_head).permute(0, 2, 1,
|
||||||
3).reshape(b, out_ip.shape[1],
|
3).reshape(b, out_ip.shape[1],
|
||||||
self.heads * self.dim_head))
|
h * self.dim_head))
|
||||||
|
|
||||||
if k_as is not None:
|
if k_as is not None:
|
||||||
# For agent state cross-attention
|
|
||||||
k_as, v_as = map(
|
|
||||||
lambda t: t.unsqueeze(3).reshape(b, t.shape[
|
|
||||||
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
|
|
||||||
b * self.heads, t.shape[1], self.dim_head).contiguous(
|
|
||||||
),
|
|
||||||
(k_as, v_as),
|
|
||||||
)
|
|
||||||
out_as = xformers.ops.memory_efficient_attention(q,
|
out_as = xformers.ops.memory_efficient_attention(q,
|
||||||
k_as,
|
k_as,
|
||||||
v_as,
|
v_as,
|
||||||
attn_bias=None,
|
attn_bias=None,
|
||||||
op=None)
|
op=None)
|
||||||
out_as = (out_as.unsqueeze(0).reshape(
|
out_as = (out_as.unsqueeze(0).reshape(
|
||||||
b, self.heads, out_as.shape[1],
|
b, h, out_as.shape[1],
|
||||||
self.dim_head).permute(0, 2, 1,
|
self.dim_head).permute(0, 2, 1,
|
||||||
3).reshape(b, out_as.shape[1],
|
3).reshape(b, out_as.shape[1],
|
||||||
self.heads * self.dim_head))
|
h * self.dim_head))
|
||||||
|
|
||||||
if k_aa is not None:
|
if k_aa is not None:
|
||||||
# For agent action cross-attention
|
|
||||||
k_aa, v_aa = map(
|
|
||||||
lambda t: t.unsqueeze(3).reshape(b, t.shape[
|
|
||||||
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
|
|
||||||
b * self.heads, t.shape[1], self.dim_head).contiguous(
|
|
||||||
),
|
|
||||||
(k_aa, v_aa),
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_mask_aa = attn_mask_aa.unsqueeze(1).repeat(1,self.heads,1,1).reshape(
|
|
||||||
b * self.heads, attn_mask_aa.shape[1], attn_mask_aa.shape[2])
|
|
||||||
attn_mask_aa = attn_mask_aa.to(q.dtype)
|
|
||||||
|
|
||||||
out_aa = xformers.ops.memory_efficient_attention(
|
out_aa = xformers.ops.memory_efficient_attention(
|
||||||
q, k_aa, v_aa, attn_bias=attn_mask_aa, op=None)
|
q, k_aa, v_aa, attn_bias=attn_mask_aa, op=None)
|
||||||
|
|
||||||
out_aa = (out_aa.unsqueeze(0).reshape(
|
out_aa = (out_aa.unsqueeze(0).reshape(
|
||||||
b, self.heads, out_aa.shape[1],
|
b, h, out_aa.shape[1],
|
||||||
self.dim_head).permute(0, 2, 1,
|
self.dim_head).permute(0, 2, 1,
|
||||||
3).reshape(b, out_aa.shape[1],
|
3).reshape(b, out_aa.shape[1],
|
||||||
self.heads * self.dim_head))
|
h * self.dim_head))
|
||||||
if exists(mask):
|
if exists(mask):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -543,7 +463,7 @@ class CrossAttention(nn.Module):
|
|||||||
col_indices = torch.arange(l2, device=target_device)
|
col_indices = torch.arange(l2, device=target_device)
|
||||||
mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
|
mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
|
||||||
mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
|
mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
|
||||||
attn_mask = torch.zeros(b, l1, l2, dtype=torch.bfloat16, device=target_device)
|
attn_mask = torch.zeros(b, l1, l2, dtype=torch.float, device=target_device)
|
||||||
attn_mask[mask] = float('-inf')
|
attn_mask[mask] = float('-inf')
|
||||||
|
|
||||||
self._attn_mask_aa_cache_key = cache_key
|
self._attn_mask_aa_cache_key = cache_key
|
||||||
|
|||||||
@@ -422,7 +422,7 @@ class WMAModel(nn.Module):
|
|||||||
self.temporal_attention = temporal_attention
|
self.temporal_attention = temporal_attention
|
||||||
time_embed_dim = model_channels * 4
|
time_embed_dim = model_channels * 4
|
||||||
self.use_checkpoint = use_checkpoint
|
self.use_checkpoint = use_checkpoint
|
||||||
self.dtype = torch.float16 if use_fp16 else torch.bfloat16
|
self.dtype = torch.float16 if use_fp16 else torch.float32
|
||||||
temporal_self_att_only = True
|
temporal_self_att_only = True
|
||||||
self.addition_attention = addition_attention
|
self.addition_attention = addition_attention
|
||||||
self.temporal_length = temporal_length
|
self.temporal_length = temporal_length
|
||||||
@@ -688,8 +688,17 @@ class WMAModel(nn.Module):
|
|||||||
# Context precomputation cache
|
# Context precomputation cache
|
||||||
self._ctx_cache_enabled = False
|
self._ctx_cache_enabled = False
|
||||||
self._ctx_cache = {}
|
self._ctx_cache = {}
|
||||||
# fs_embed cache
|
# Reusable CUDA stream for parallel state_unet / action_unet
|
||||||
self._fs_embed_cache = None
|
self._state_stream = torch.cuda.Stream()
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
state = self.__dict__.copy()
|
||||||
|
state.pop('_state_stream', None)
|
||||||
|
return state
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
self.__dict__.update(state)
|
||||||
|
self._state_stream = torch.cuda.Stream()
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
@@ -791,20 +800,16 @@ class WMAModel(nn.Module):
|
|||||||
|
|
||||||
# Combine emb
|
# Combine emb
|
||||||
if self.fs_condition:
|
if self.fs_condition:
|
||||||
if self._ctx_cache_enabled and self._fs_embed_cache is not None:
|
if fs is None:
|
||||||
fs_embed = self._fs_embed_cache
|
fs = torch.tensor([self.default_fs] * b,
|
||||||
else:
|
dtype=torch.long,
|
||||||
if fs is None:
|
device=x.device)
|
||||||
fs = torch.tensor([self.default_fs] * b,
|
fs_emb = timestep_embedding(fs,
|
||||||
dtype=torch.long,
|
self.model_channels,
|
||||||
device=x.device)
|
repeat_only=False).type(x.dtype)
|
||||||
fs_emb = timestep_embedding(fs,
|
|
||||||
self.model_channels,
|
fs_embed = self.fps_embedding(fs_emb)
|
||||||
repeat_only=False).type(x.dtype)
|
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
|
||||||
fs_embed = self.fps_embedding(fs_emb)
|
|
||||||
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
|
|
||||||
if self._ctx_cache_enabled:
|
|
||||||
self._fs_embed_cache = fs_embed
|
|
||||||
emb = emb + fs_embed
|
emb = emb + fs_embed
|
||||||
|
|
||||||
h = x.type(self.dtype)
|
h = x.type(self.dtype)
|
||||||
@@ -848,15 +853,16 @@ class WMAModel(nn.Module):
|
|||||||
|
|
||||||
if not self.base_model_gen_only:
|
if not self.base_model_gen_only:
|
||||||
ba, _, _ = x_action.shape
|
ba, _, _ = x_action.shape
|
||||||
|
ts_state = timesteps[:ba] if b > 1 else timesteps
|
||||||
|
# Run action_unet and state_unet in parallel via CUDA streams
|
||||||
|
s_stream = self._state_stream
|
||||||
|
s_stream.wait_stream(torch.cuda.current_stream())
|
||||||
|
with torch.cuda.stream(s_stream):
|
||||||
|
s_y = self.state_unet(x_state, ts_state, hs_a,
|
||||||
|
context_action[:2], **kwargs)
|
||||||
a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
|
a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
|
||||||
context_action[:2], **kwargs)
|
context_action[:2], **kwargs)
|
||||||
# Predict state
|
torch.cuda.current_stream().wait_stream(s_stream)
|
||||||
if b > 1:
|
|
||||||
s_y = self.state_unet(x_state, timesteps[:ba], hs_a,
|
|
||||||
context_action[:2], **kwargs)
|
|
||||||
else:
|
|
||||||
s_y = self.state_unet(x_state, timesteps, hs_a,
|
|
||||||
context_action[:2], **kwargs)
|
|
||||||
else:
|
else:
|
||||||
a_y = torch.zeros_like(x_action)
|
a_y = torch.zeros_like(x_action)
|
||||||
s_y = torch.zeros_like(x_state)
|
s_y = torch.zeros_like(x_state)
|
||||||
@@ -870,7 +876,6 @@ def enable_ctx_cache(model):
|
|||||||
if isinstance(m, WMAModel):
|
if isinstance(m, WMAModel):
|
||||||
m._ctx_cache_enabled = True
|
m._ctx_cache_enabled = True
|
||||||
m._ctx_cache = {}
|
m._ctx_cache = {}
|
||||||
m._fs_embed_cache = None
|
|
||||||
# conditional_unet1d cache
|
# conditional_unet1d cache
|
||||||
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
|
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
@@ -885,7 +890,6 @@ def disable_ctx_cache(model):
|
|||||||
if isinstance(m, WMAModel):
|
if isinstance(m, WMAModel):
|
||||||
m._ctx_cache_enabled = False
|
m._ctx_cache_enabled = False
|
||||||
m._ctx_cache = {}
|
m._ctx_cache = {}
|
||||||
m._fs_embed_cache = None
|
|
||||||
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
|
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
if isinstance(m, ConditionalUnet1D):
|
if isinstance(m, ConditionalUnet1D):
|
||||||
|
|||||||
@@ -7,9 +7,7 @@
|
|||||||
#
|
#
|
||||||
# thanks!
|
# thanks!
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
from unifolm_wma.utils.utils import instantiate_from_config
|
from unifolm_wma.utils.utils import instantiate_from_config
|
||||||
|
|
||||||
|
|
||||||
@@ -80,11 +78,7 @@ def nonlinearity(type='silu'):
|
|||||||
class GroupNormSpecific(nn.GroupNorm):
|
class GroupNormSpecific(nn.GroupNorm):
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
with torch.amp.autocast('cuda', enabled=False):
|
return super().forward(x.float()).type(x.dtype)
|
||||||
return F.group_norm(x, self.num_groups,
|
|
||||||
self.weight.to(x.dtype) if self.weight is not None else None,
|
|
||||||
self.bias.to(x.dtype) if self.bias is not None else None,
|
|
||||||
self.eps)
|
|
||||||
|
|
||||||
|
|
||||||
def normalization(channels, num_groups=32):
|
def normalization(channels, num_groups=32):
|
||||||
|
|||||||
@@ -1,32 +1,16 @@
|
|||||||
2026-02-08 05:20:49.828675: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-19 18:55:32.160020: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
2026-02-08 05:20:49.831563: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-19 18:55:32.207538: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-08 05:20:49.861366: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-19 18:55:32.207581: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-08 05:20:49.861402: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
2026-02-19 18:55:32.208613: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-08 05:20:49.862974: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-19 18:55:32.215249: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
2026-02-08 05:20:49.870402: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-08 05:20:49.870647: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
2026-02-19 18:55:33.121466: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
|
||||||
2026-02-08 05:20:50.486843: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
|
||||||
Global seed set to 123
|
Global seed set to 123
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
>>> Prepared model loaded.
|
||||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
>>> Diffusion backbone (model.model) converted to FP16.
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
>>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
>>> Encoders (cond_stage_model, embedder) converted to FP16.
|
||||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
state_dict = torch.load(ckpt, map_location="cpu")
|
|
||||||
>>> model checkpoint loaded.
|
|
||||||
>>> Load pre-trained model ...
|
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -44,69 +28,15 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
|
✓ KV fused: 66 attention layers
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
|
||||||
proj = linear(q, w, b)
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
|
||||||
attn_output = scaled_dot_product_attention(
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s]
|
0%| | 0/11 [00:00<?, ?it/s]
|
||||||
9%|▉ | 1/11 [00:23<03:56, 23.63s/it]
|
9%|▉ | 1/11 [00:23<03:56, 23.63s/it]
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
18%|█▊ | 2/11 [00:46<03:29, 23.24s/it]
|
18%|█▊ | 2/11 [00:46<03:29, 23.24s/it]
|
||||||
27%|██▋ | 3/11 [01:09<03:05, 23.25s/it]
|
27%|██▋ | 3/11 [01:09<03:05, 23.25s/it]
|
||||||
36%|███▋ | 4/11 [01:33<02:43, 23.31s/it]
|
36%|███▋ | 4/11 [01:33<02:43, 23.31s/it]
|
||||||
@@ -139,6 +69,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 6: generating actions ...
|
>>> Step 6: generating actions ...
|
||||||
>>> Step 6: interacting with world model ...
|
>>> Step 6: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 7: generating actions ...
|
>>> Step 7: generating actions ...
|
||||||
>>> Step 7: interacting with world model ...
|
>>> Step 7: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_g1_pack_camera/case1/unitree_g1_pack_camera_case1.mp4",
|
"gt_video": "unitree_g1_pack_camera/case1/unitree_g1_pack_camera_case1.mp4",
|
||||||
"pred_video": "unitree_g1_pack_camera/case1/output/inference/unitree_g1_pack_camera_case1_amd.mp4",
|
"pred_video": "unitree_g1_pack_camera/case1/output/inference/0_full_fs6.mp4",
|
||||||
"psnr": 16.415668383379177
|
"psnr": 32.34126103448495
|
||||||
}
|
}
|
||||||
@@ -20,5 +20,6 @@ dataset="unitree_g1_pack_camera"
|
|||||||
--n_iter 11 \
|
--n_iter 11 \
|
||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae
|
--perframe_ae \
|
||||||
|
--fast_policy_no_decode
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
@@ -1,46 +1,16 @@
|
|||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
2026-02-19 19:00:05.944067: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
__import__("pkg_resources").declare_namespace(__name__)
|
2026-02-19 19:00:05.991354: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-08 18:28:48.960238: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-19 19:00:05.991392: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-08 18:28:48.963331: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-19 19:00:05.992414: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-08 18:28:48.995688: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-19 19:00:05.999050: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
2026-02-08 18:28:48.995732: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-08 18:28:48.997547: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-19 19:00:06.916175: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
2026-02-08 18:28:49.005673: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
Global seed set to 123
|
||||||
2026-02-08 18:28:49.005948: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
>>> Prepared model loaded.
|
||||||
2026-02-08 18:28:50.009660: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
>>> Diffusion backbone (model.model) converted to FP16.
|
||||||
[rank: 0] Global seed set to 123
|
>>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
>>> Encoders (cond_stage_model, embedder) converted to FP16.
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
|
||||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:198: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
state_dict = torch.load(ckpt, map_location="cpu")
|
|
||||||
>>> model checkpoint loaded.
|
|
||||||
>>> Load pre-trained model ...
|
|
||||||
>>> Applying precision settings:
|
|
||||||
- Diffusion dtype: bf16
|
|
||||||
- Projector mode: bf16_full
|
|
||||||
- Encoder mode: bf16_full
|
|
||||||
- VAE dtype: fp32
|
|
||||||
✓ Diffusion model weights converted to bfloat16
|
|
||||||
✓ Projectors converted to bfloat16
|
|
||||||
✓ Encoders converted to bfloat16
|
|
||||||
✓ VAE kept in fp32 for best quality
|
|
||||||
⚠ Found 849 fp32 params, converting to bf16
|
|
||||||
✓ All parameters converted to bfloat16
|
|
||||||
✓ torch.compile: 3 ResBlocks in output_blocks[5, 8, 9]
|
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -58,69 +28,15 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
|
✓ KV fused: 66 attention layers
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
|
||||||
proj = linear(q, w, b)
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
|
||||||
attn_output = scaled_dot_product_attention(
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s]
|
0%| | 0/11 [00:00<?, ?it/s]
|
||||||
9%|▉ | 1/11 [00:24<04:00, 24.04s/it]
|
9%|▉ | 1/11 [00:24<04:00, 24.04s/it]
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
18%|█▊ | 2/11 [00:47<03:31, 23.55s/it]
|
18%|█▊ | 2/11 [00:47<03:31, 23.55s/it]
|
||||||
27%|██▋ | 3/11 [01:10<03:07, 23.43s/it]
|
27%|██▋ | 3/11 [01:10<03:07, 23.43s/it]
|
||||||
36%|███▋ | 4/11 [01:33<02:43, 23.42s/it]
|
36%|███▋ | 4/11 [01:33<02:43, 23.42s/it]
|
||||||
@@ -153,6 +69,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 6: generating actions ...
|
>>> Step 6: generating actions ...
|
||||||
>>> Step 6: interacting with world model ...
|
>>> Step 6: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 7: generating actions ...
|
>>> Step 7: generating actions ...
|
||||||
>>> Step 7: interacting with world model ...
|
>>> Step 7: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_g1_pack_camera/case2/unitree_g1_pack_camera_case2.mp4",
|
"gt_video": "unitree_g1_pack_camera/case2/unitree_g1_pack_camera_case2.mp4",
|
||||||
"pred_video": "unitree_g1_pack_camera/case2/output/inference/unitree_g1_pack_camera_case2_amd.mp4",
|
"pred_video": "unitree_g1_pack_camera/case2/output/inference/50_full_fs6.mp4",
|
||||||
"psnr": 19.515250190529375
|
"psnr": 37.49178506869336
|
||||||
}
|
}
|
||||||
@@ -20,5 +20,6 @@ dataset="unitree_g1_pack_camera"
|
|||||||
--n_iter 11 \
|
--n_iter 11 \
|
||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae
|
--perframe_ae \
|
||||||
|
--fast_policy_no_decode
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
@@ -1,32 +1,16 @@
|
|||||||
2026-02-08 05:08:32.803904: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-19 19:04:41.036634: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
2026-02-08 05:08:32.807010: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-19 19:04:41.084414: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-08 05:08:32.837936: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-19 19:04:41.084452: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-08 05:08:32.837978: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
2026-02-19 19:04:41.085481: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-08 05:08:32.839785: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-19 19:04:41.092287: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
2026-02-08 05:08:32.847835: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-08 05:08:32.848223: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
2026-02-19 19:04:42.000614: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
Global seed set to 123
|
||||||
2026-02-08 05:08:34.120114: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||||
[rank: 0] Global seed set to 123
|
>>> Prepared model loaded.
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
>>> Diffusion backbone (model.model) converted to FP16.
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
>>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.
|
||||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
>>> Encoders (cond_stage_model, embedder) converted to FP16.
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
state_dict = torch.load(ckpt, map_location="cpu")
|
|
||||||
>>> model checkpoint loaded.
|
|
||||||
>>> Load pre-trained model ...
|
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -44,69 +28,15 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
|
✓ KV fused: 66 attention layers
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
|
||||||
proj = linear(q, w, b)
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
|
||||||
attn_output = scaled_dot_product_attention(
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s]
|
0%| | 0/11 [00:00<?, ?it/s]
|
||||||
9%|▉ | 1/11 [00:24<04:01, 24.19s/it]
|
9%|▉ | 1/11 [00:24<04:01, 24.19s/it]
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
18%|█▊ | 2/11 [00:47<03:32, 23.64s/it]
|
18%|█▊ | 2/11 [00:47<03:32, 23.64s/it]
|
||||||
27%|██▋ | 3/11 [01:10<03:08, 23.50s/it]
|
27%|██▋ | 3/11 [01:10<03:08, 23.50s/it]
|
||||||
36%|███▋ | 4/11 [01:34<02:44, 23.47s/it]
|
36%|███▋ | 4/11 [01:34<02:44, 23.47s/it]
|
||||||
@@ -139,6 +69,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 6: generating actions ...
|
>>> Step 6: generating actions ...
|
||||||
>>> Step 6: interacting with world model ...
|
>>> Step 6: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 7: generating actions ...
|
>>> Step 7: generating actions ...
|
||||||
>>> Step 7: interacting with world model ...
|
>>> Step 7: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_g1_pack_camera/case3/unitree_g1_pack_camera_case3.mp4",
|
"gt_video": "unitree_g1_pack_camera/case3/unitree_g1_pack_camera_case3.mp4",
|
||||||
"pred_video": "unitree_g1_pack_camera/case3/output/inference/unitree_g1_pack_camera_case3_amd.mp4",
|
"pred_video": "unitree_g1_pack_camera/case3/output/inference/100_full_fs6.mp4",
|
||||||
"psnr": 19.429578160315536
|
"psnr": 29.88155122131729
|
||||||
}
|
}
|
||||||
@@ -20,5 +20,6 @@ dataset="unitree_g1_pack_camera"
|
|||||||
--n_iter 11 \
|
--n_iter 11 \
|
||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae
|
--perframe_ae \
|
||||||
|
--fast_policy_no_decode
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
@@ -1,32 +1,16 @@
|
|||||||
2026-02-08 05:29:19.728303: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-19 19:09:16.122268: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
2026-02-08 05:29:19.731620: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-19 19:09:16.170290: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-08 05:29:19.761276: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-19 19:09:16.170331: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-08 05:29:19.761301: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
2026-02-19 19:09:16.171349: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-08 05:29:19.762880: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-19 19:09:16.177993: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
2026-02-08 05:29:19.770578: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-08 05:29:19.771072: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
2026-02-19 19:09:17.087425: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
|
||||||
2026-02-08 05:29:21.043661: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
|
||||||
Global seed set to 123
|
Global seed set to 123
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
>>> Prepared model loaded.
|
||||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
>>> Diffusion backbone (model.model) converted to FP16.
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
>>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
>>> Encoders (cond_stage_model, embedder) converted to FP16.
|
||||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
state_dict = torch.load(ckpt, map_location="cpu")
|
|
||||||
>>> model checkpoint loaded.
|
|
||||||
>>> Load pre-trained model ...
|
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -44,69 +28,15 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
|
✓ KV fused: 66 attention layers
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
|
||||||
proj = linear(q, w, b)
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
|
||||||
attn_output = scaled_dot_product_attention(
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s]
|
0%| | 0/11 [00:00<?, ?it/s]
|
||||||
9%|▉ | 1/11 [00:24<04:01, 24.17s/it]
|
9%|▉ | 1/11 [00:24<04:01, 24.17s/it]
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
18%|█▊ | 2/11 [00:47<03:32, 23.62s/it]
|
18%|█▊ | 2/11 [00:47<03:32, 23.62s/it]
|
||||||
27%|██▋ | 3/11 [01:10<03:07, 23.49s/it]
|
27%|██▋ | 3/11 [01:10<03:07, 23.49s/it]
|
||||||
36%|███▋ | 4/11 [01:34<02:44, 23.46s/it]
|
36%|███▋ | 4/11 [01:34<02:44, 23.46s/it]
|
||||||
@@ -139,6 +69,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 6: generating actions ...
|
>>> Step 6: generating actions ...
|
||||||
>>> Step 6: interacting with world model ...
|
>>> Step 6: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 7: generating actions ...
|
>>> Step 7: generating actions ...
|
||||||
>>> Step 7: interacting with world model ...
|
>>> Step 7: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_g1_pack_camera/case4/unitree_g1_pack_camera_case4.mp4",
|
"gt_video": "unitree_g1_pack_camera/case4/unitree_g1_pack_camera_case4.mp4",
|
||||||
"pred_video": "unitree_g1_pack_camera/case4/output/inference/unitree_g1_pack_camera_case4_amd.mp4",
|
"pred_video": "unitree_g1_pack_camera/case4/output/inference/200_full_fs6.mp4",
|
||||||
"psnr": 17.80386833747375
|
"psnr": 35.62512454155058
|
||||||
}
|
}
|
||||||
@@ -20,5 +20,6 @@ dataset="unitree_g1_pack_camera"
|
|||||||
--n_iter 11 \
|
--n_iter 11 \
|
||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae
|
--perframe_ae \
|
||||||
|
--fast_policy_no_decode
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
@@ -1,46 +1,16 @@
|
|||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
2026-02-19 19:13:51.554194: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
__import__("pkg_resources").declare_namespace(__name__)
|
2026-02-19 19:13:51.601580: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-09 18:39:50.119842: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-19 19:13:51.601622: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-09 18:39:50.123128: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-19 19:13:51.602646: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-09 18:39:50.156652: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-19 19:13:51.609297: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
2026-02-09 18:39:50.156708: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-09 18:39:50.158926: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-19 19:13:52.517676: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
2026-02-09 18:39:50.167779: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
Global seed set to 123
|
||||||
2026-02-09 18:39:50.168073: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
>>> Prepared model loaded.
|
||||||
2026-02-09 18:39:50.915144: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
>>> Diffusion backbone (model.model) converted to FP16.
|
||||||
[rank: 0] Global seed set to 123
|
>>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
>>> Encoders (cond_stage_model, embedder) converted to FP16.
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
|
||||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:198: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
state_dict = torch.load(ckpt, map_location="cpu")
|
|
||||||
>>> model checkpoint loaded.
|
|
||||||
>>> Load pre-trained model ...
|
|
||||||
>>> Applying precision settings:
|
|
||||||
- Diffusion dtype: bf16
|
|
||||||
- Projector mode: bf16_full
|
|
||||||
- Encoder mode: bf16_full
|
|
||||||
- VAE dtype: bf16
|
|
||||||
✓ Diffusion model weights converted to bfloat16
|
|
||||||
✓ Projectors converted to bfloat16
|
|
||||||
✓ Encoders converted to bfloat16
|
|
||||||
✓ VAE converted to bfloat16
|
|
||||||
⚠ Found 601 fp32 params, converting to bf16
|
|
||||||
✓ All parameters converted to bfloat16
|
|
||||||
✓ torch.compile: 3 ResBlocks in output_blocks[5, 8, 9]
|
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -58,65 +28,15 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
|
✓ KV fused: 66 attention layers
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/8 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
|
||||||
0%| | 0/8 [00:00<?, ?it/s]
|
0%| | 0/8 [00:00<?, ?it/s]
|
||||||
12%|█▎ | 1/8 [00:24<02:49, 24.16s/it]
|
12%|█▎ | 1/8 [00:24<02:49, 24.16s/it]
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
25%|██▌ | 2/8 [00:47<02:22, 23.67s/it]
|
25%|██▌ | 2/8 [00:47<02:22, 23.67s/it]
|
||||||
38%|███▊ | 3/8 [01:10<01:57, 23.55s/it]
|
38%|███▊ | 3/8 [01:10<01:57, 23.55s/it]
|
||||||
50%|█████ | 4/8 [01:34<01:34, 23.51s/it]
|
50%|█████ | 4/8 [01:34<01:34, 23.51s/it]
|
||||||
@@ -140,6 +60,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 4: generating actions ...
|
>>> Step 4: generating actions ...
|
||||||
>>> Step 4: interacting with world model ...
|
>>> Step 4: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 5: generating actions ...
|
>>> Step 5: generating actions ...
|
||||||
>>> Step 5: interacting with world model ...
|
>>> Step 5: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case1/unitree_z1_dual_arm_cleanup_pencils_case1.mp4",
|
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case1/unitree_z1_dual_arm_cleanup_pencils_case1.mp4",
|
||||||
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/unitree_z1_dual_arm_cleanup_pencils_case1_amd.mp4",
|
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/0_full_fs4.mp4",
|
||||||
"psnr": 19.586376345676264
|
"psnr": 38.269577028444445
|
||||||
}
|
}
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
{
|
|
||||||
"gt_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/unitree_z1_dual_arm_cleanup_pencils_case1_amd.mp4",
|
|
||||||
"pred_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/0_full_fs4.mp4",
|
|
||||||
"psnr": 31.802224855380352
|
|
||||||
}
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
#\!/bin/bash
|
|
||||||
res_dir="unitree_z1_dual_arm_cleanup_pencils/case1"
|
|
||||||
dataset="unitree_z1_dual_arm_cleanup_pencils"
|
|
||||||
|
|
||||||
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/profile_iteration.py --seed 123 --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --savedir "${res_dir}/profile_output" --prompt_dir "${res_dir}/world_model_interaction_prompts" --dataset ${dataset} --bs 1 --height 320 --width 512 --unconditional_guidance_scale 1.0 --ddim_steps 50 --ddim_eta 1.0 --video_length 16 --frame_stride 4 --exe_steps 16 --n_iter 5 --warmup 1 --timestep_spacing uniform_trailing --guidance_rescale 0.7 --perframe_ae --vae_dtype bf16 --fast_policy_no_decode --csv "${res_dir}/profile_output/baseline.csv" 2>&1 | tee "${res_dir}/profile_output/profile.log"
|
|
||||||
@@ -2,9 +2,9 @@ res_dir="unitree_z1_dual_arm_cleanup_pencils/case1"
|
|||||||
dataset="unitree_z1_dual_arm_cleanup_pencils"
|
dataset="unitree_z1_dual_arm_cleanup_pencils"
|
||||||
|
|
||||||
{
|
{
|
||||||
time TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
||||||
--seed 123 \
|
--seed 123 \
|
||||||
--ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
|
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||||
--config configs/inference/world_model_interaction.yaml \
|
--config configs/inference/world_model_interaction.yaml \
|
||||||
--savedir "${res_dir}/output" \
|
--savedir "${res_dir}/output" \
|
||||||
--bs 1 --height 320 --width 512 \
|
--bs 1 --height 320 --width 512 \
|
||||||
@@ -21,6 +21,5 @@ dataset="unitree_z1_dual_arm_cleanup_pencils"
|
|||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae \
|
--perframe_ae \
|
||||||
--vae_dtype bf16 \
|
|
||||||
--fast_policy_no_decode
|
--fast_policy_no_decode
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
@@ -1,34 +1,16 @@
|
|||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
2026-02-19 19:17:16.282875: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
__import__("pkg_resources").declare_namespace(__name__)
|
2026-02-19 19:17:16.330519: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-08 06:59:34.465946: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-19 19:17:16.330561: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-08 06:59:34.469367: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-19 19:17:16.331631: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-08 06:59:34.500805: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-19 19:17:16.338413: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
2026-02-08 06:59:34.500837: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-08 06:59:34.502917: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-19 19:17:17.250653: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
2026-02-08 06:59:34.511434: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
|
||||||
2026-02-08 06:59:34.511678: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
|
||||||
2026-02-08 06:59:35.478194: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
|
||||||
Global seed set to 123
|
Global seed set to 123
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
>>> Prepared model loaded.
|
||||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
>>> Diffusion backbone (model.model) converted to FP16.
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
>>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
>>> Encoders (cond_stage_model, embedder) converted to FP16.
|
||||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
state_dict = torch.load(ckpt, map_location="cpu")
|
|
||||||
>>> model checkpoint loaded.
|
|
||||||
>>> Load pre-trained model ...
|
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -46,69 +28,15 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
|
✓ KV fused: 66 attention layers
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/8 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
|
||||||
proj = linear(q, w, b)
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
|
||||||
attn_output = scaled_dot_product_attention(
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
|
||||||
0%| | 0/8 [00:00<?, ?it/s]
|
0%| | 0/8 [00:00<?, ?it/s]
|
||||||
12%|█▎ | 1/8 [00:24<02:48, 24.06s/it]
|
12%|█▎ | 1/8 [00:24<02:48, 24.06s/it]
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
25%|██▌ | 2/8 [00:47<02:21, 23.61s/it]
|
25%|██▌ | 2/8 [00:47<02:21, 23.61s/it]
|
||||||
38%|███▊ | 3/8 [01:10<01:57, 23.47s/it]
|
38%|███▊ | 3/8 [01:10<01:57, 23.47s/it]
|
||||||
50%|█████ | 4/8 [01:34<01:33, 23.44s/it]
|
50%|█████ | 4/8 [01:34<01:33, 23.44s/it]
|
||||||
@@ -132,6 +60,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 4: generating actions ...
|
>>> Step 4: generating actions ...
|
||||||
>>> Step 4: interacting with world model ...
|
>>> Step 4: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 5: generating actions ...
|
>>> Step 5: generating actions ...
|
||||||
>>> Step 5: interacting with world model ...
|
>>> Step 5: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case2/unitree_z1_dual_arm_cleanup_pencils_case2.mp4",
|
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case2/unitree_z1_dual_arm_cleanup_pencils_case2.mp4",
|
||||||
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case2/output/inference/unitree_z1_dual_arm_cleanup_pencils_case2_amd.mp4",
|
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case2/output/inference/50_full_fs4.mp4",
|
||||||
"psnr": 20.484298972158296
|
"psnr": 44.50028075962896
|
||||||
}
|
}
|
||||||
@@ -20,5 +20,6 @@ dataset="unitree_z1_dual_arm_cleanup_pencils"
|
|||||||
--n_iter 8 \
|
--n_iter 8 \
|
||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae
|
--perframe_ae \
|
||||||
|
--fast_policy_no_decode
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
@@ -1,34 +1,16 @@
|
|||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
2026-02-19 19:20:40.444703: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
__import__("pkg_resources").declare_namespace(__name__)
|
2026-02-19 19:20:40.492237: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-08 07:18:52.629976: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-19 19:20:40.492278: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-08 07:18:52.633025: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-19 19:20:40.493360: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-08 07:18:52.663985: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-19 19:20:40.500130: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
2026-02-08 07:18:52.664018: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-08 07:18:52.665837: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-19 19:20:41.414718: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
2026-02-08 07:18:52.673889: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
Global seed set to 123
|
||||||
2026-02-08 07:18:52.674218: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
>>> Prepared model loaded.
|
||||||
2026-02-08 07:18:53.298338: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
>>> Diffusion backbone (model.model) converted to FP16.
|
||||||
[rank: 0] Global seed set to 123
|
>>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
>>> Encoders (cond_stage_model, embedder) converted to FP16.
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
|
||||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
state_dict = torch.load(ckpt, map_location="cpu")
|
|
||||||
>>> model checkpoint loaded.
|
|
||||||
>>> Load pre-trained model ...
|
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -46,69 +28,15 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
|
✓ KV fused: 66 attention layers
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/8 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
|
||||||
proj = linear(q, w, b)
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
|
||||||
attn_output = scaled_dot_product_attention(
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
|
||||||
0%| | 0/8 [00:00<?, ?it/s]
|
0%| | 0/8 [00:00<?, ?it/s]
|
||||||
12%|█▎ | 1/8 [00:24<02:48, 24.06s/it]
|
12%|█▎ | 1/8 [00:24<02:48, 24.06s/it]
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
25%|██▌ | 2/8 [00:47<02:21, 23.58s/it]
|
25%|██▌ | 2/8 [00:47<02:21, 23.58s/it]
|
||||||
38%|███▊ | 3/8 [01:10<01:57, 23.45s/it]
|
38%|███▊ | 3/8 [01:10<01:57, 23.45s/it]
|
||||||
50%|█████ | 4/8 [01:33<01:33, 23.41s/it]
|
50%|█████ | 4/8 [01:33<01:33, 23.41s/it]
|
||||||
@@ -132,6 +60,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 4: generating actions ...
|
>>> Step 4: generating actions ...
|
||||||
>>> Step 4: interacting with world model ...
|
>>> Step 4: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 5: generating actions ...
|
>>> Step 5: generating actions ...
|
||||||
>>> Step 5: interacting with world model ...
|
>>> Step 5: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case3/unitree_z1_dual_arm_cleanup_pencils_case3.mp4",
|
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case3/unitree_z1_dual_arm_cleanup_pencils_case3.mp4",
|
||||||
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case3/output/inference/unitree_z1_dual_arm_cleanup_pencils_case3_amd.mp4",
|
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case3/output/inference/100_full_fs4.mp4",
|
||||||
"psnr": 21.20205061239349
|
"psnr": 32.29959078097713
|
||||||
}
|
}
|
||||||
@@ -20,5 +20,6 @@ dataset="unitree_z1_dual_arm_cleanup_pencils"
|
|||||||
--n_iter 8 \
|
--n_iter 8 \
|
||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae
|
--perframe_ae \
|
||||||
|
--fast_policy_no_decode
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
@@ -1,34 +1,16 @@
|
|||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
2026-02-19 19:24:05.230366: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
__import__("pkg_resources").declare_namespace(__name__)
|
2026-02-19 19:24:05.278058: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-08 07:22:15.333099: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-19 19:24:05.278100: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-08 07:22:15.336215: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-19 19:24:05.279133: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-08 07:22:15.366489: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-19 19:24:05.285789: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
2026-02-08 07:22:15.366522: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-08 07:22:15.368294: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-19 19:24:06.199101: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
2026-02-08 07:22:15.376202: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
Global seed set to 123
|
||||||
2026-02-08 07:22:15.376444: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
>>> Prepared model loaded.
|
||||||
2026-02-08 07:22:15.995383: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
>>> Diffusion backbone (model.model) converted to FP16.
|
||||||
[rank: 0] Global seed set to 123
|
>>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
>>> Encoders (cond_stage_model, embedder) converted to FP16.
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
|
||||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
state_dict = torch.load(ckpt, map_location="cpu")
|
|
||||||
>>> model checkpoint loaded.
|
|
||||||
>>> Load pre-trained model ...
|
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -46,69 +28,15 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
|
✓ KV fused: 66 attention layers
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/8 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
|
||||||
proj = linear(q, w, b)
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
|
||||||
attn_output = scaled_dot_product_attention(
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
|
||||||
0%| | 0/8 [00:00<?, ?it/s]
|
0%| | 0/8 [00:00<?, ?it/s]
|
||||||
12%|█▎ | 1/8 [00:24<02:48, 24.06s/it]
|
12%|█▎ | 1/8 [00:24<02:48, 24.06s/it]
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
25%|██▌ | 2/8 [00:47<02:21, 23.56s/it]
|
25%|██▌ | 2/8 [00:47<02:21, 23.56s/it]
|
||||||
38%|███▊ | 3/8 [01:10<01:57, 23.45s/it]
|
38%|███▊ | 3/8 [01:10<01:57, 23.45s/it]
|
||||||
50%|█████ | 4/8 [01:33<01:33, 23.43s/it]
|
50%|█████ | 4/8 [01:33<01:33, 23.43s/it]
|
||||||
@@ -132,6 +60,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 4: generating actions ...
|
>>> Step 4: generating actions ...
|
||||||
>>> Step 4: interacting with world model ...
|
>>> Step 4: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 5: generating actions ...
|
>>> Step 5: generating actions ...
|
||||||
>>> Step 5: interacting with world model ...
|
>>> Step 5: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case4/unitree_z1_dual_arm_cleanup_pencils_case4.mp4",
|
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case4/unitree_z1_dual_arm_cleanup_pencils_case4.mp4",
|
||||||
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case4/output/inference/unitree_z1_dual_arm_cleanup_pencils_case4_amd.mp4",
|
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case4/output/inference/200_full_fs4.mp4",
|
||||||
"psnr": 21.130122583788612
|
"psnr": 45.051241961122535
|
||||||
}
|
}
|
||||||
@@ -20,5 +20,6 @@ dataset="unitree_z1_dual_arm_cleanup_pencils"
|
|||||||
--n_iter 8 \
|
--n_iter 8 \
|
||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae
|
--perframe_ae \
|
||||||
|
--fast_policy_no_decode
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
@@ -1,34 +1,16 @@
|
|||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
2026-02-19 19:27:29.317502: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
__import__("pkg_resources").declare_namespace(__name__)
|
2026-02-19 19:27:29.365030: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-08 07:24:40.357099: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-19 19:27:29.365079: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-08 07:24:40.360365: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-19 19:27:29.366111: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-08 07:24:40.391744: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-19 19:27:29.372733: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
2026-02-08 07:24:40.391772: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-08 07:24:40.393608: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-19 19:27:30.291220: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
2026-02-08 07:24:40.401837: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
|
||||||
2026-02-08 07:24:40.402077: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
|
||||||
2026-02-08 07:24:41.022382: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
|
||||||
Global seed set to 123
|
Global seed set to 123
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
>>> Prepared model loaded.
|
||||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
>>> Diffusion backbone (model.model) converted to FP16.
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
>>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
>>> Encoders (cond_stage_model, embedder) converted to FP16.
|
||||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
state_dict = torch.load(ckpt, map_location="cpu")
|
|
||||||
>>> model checkpoint loaded.
|
|
||||||
>>> Load pre-trained model ...
|
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -46,69 +28,15 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
|
✓ KV fused: 66 attention layers
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/7 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
|
||||||
proj = linear(q, w, b)
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
|
||||||
attn_output = scaled_dot_product_attention(
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
|
||||||
0%| | 0/7 [00:00<?, ?it/s]
|
0%| | 0/7 [00:00<?, ?it/s]
|
||||||
14%|█▍ | 1/7 [00:24<02:24, 24.09s/it]
|
14%|█▍ | 1/7 [00:24<02:24, 24.09s/it]
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
29%|██▊ | 2/7 [00:47<01:57, 23.59s/it]
|
29%|██▊ | 2/7 [00:47<01:57, 23.59s/it]
|
||||||
43%|████▎ | 3/7 [01:10<01:33, 23.46s/it]
|
43%|████▎ | 3/7 [01:10<01:33, 23.46s/it]
|
||||||
57%|█████▋ | 4/7 [01:33<01:10, 23.42s/it]
|
57%|█████▋ | 4/7 [01:33<01:10, 23.42s/it]
|
||||||
@@ -129,6 +57,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 3: interacting with world model ...
|
>>> Step 3: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 4: generating actions ...
|
>>> Step 4: generating actions ...
|
||||||
>>> Step 4: interacting with world model ...
|
>>> Step 4: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 5: generating actions ...
|
>>> Step 5: generating actions ...
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_dual_arm_stackbox/case1/unitree_z1_dual_arm_stackbox_case1.mp4",
|
"gt_video": "unitree_z1_dual_arm_stackbox/case1/unitree_z1_dual_arm_stackbox_case1.mp4",
|
||||||
"pred_video": "unitree_z1_dual_arm_stackbox/case1/output/inference/unitree_z1_dual_arm_stackbox_case1_amd.mp4",
|
"pred_video": "unitree_z1_dual_arm_stackbox/case1/output/inference/5_full_fs4.mp4",
|
||||||
"psnr": 21.258130518117493
|
"psnr": 42.717688631296596
|
||||||
}
|
}
|
||||||
@@ -2,7 +2,7 @@ res_dir="unitree_z1_dual_arm_stackbox/case1"
|
|||||||
dataset="unitree_z1_dual_arm_stackbox"
|
dataset="unitree_z1_dual_arm_stackbox"
|
||||||
|
|
||||||
{
|
{
|
||||||
time CUDA_VISIBLE_DEVICES=7 python3 scripts/evaluation/world_model_interaction.py \
|
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
||||||
--seed 123 \
|
--seed 123 \
|
||||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||||
--config configs/inference/world_model_interaction.yaml \
|
--config configs/inference/world_model_interaction.yaml \
|
||||||
@@ -20,5 +20,6 @@ dataset="unitree_z1_dual_arm_stackbox"
|
|||||||
--n_iter 7 \
|
--n_iter 7 \
|
||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae
|
--perframe_ae \
|
||||||
|
--fast_policy_no_decode
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
@@ -1,34 +1,16 @@
|
|||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
2026-02-19 19:30:30.058862: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
__import__("pkg_resources").declare_namespace(__name__)
|
2026-02-19 19:30:30.106200: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-08 07:25:18.653033: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-19 19:30:30.106243: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-08 07:25:18.656060: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-19 19:30:30.107276: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-08 07:25:18.687077: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-19 19:30:30.113917: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
2026-02-08 07:25:18.687119: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-08 07:25:18.688915: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-19 19:30:31.026240: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
2026-02-08 07:25:18.697008: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
|
||||||
2026-02-08 07:25:18.697255: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
|
||||||
2026-02-08 07:25:19.338303: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
|
||||||
Global seed set to 123
|
Global seed set to 123
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
>>> Prepared model loaded.
|
||||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
>>> Diffusion backbone (model.model) converted to FP16.
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
>>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
>>> Encoders (cond_stage_model, embedder) converted to FP16.
|
||||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
state_dict = torch.load(ckpt, map_location="cpu")
|
|
||||||
>>> model checkpoint loaded.
|
|
||||||
>>> Load pre-trained model ...
|
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -46,69 +28,15 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
|
✓ KV fused: 66 attention layers
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/7 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
|
||||||
proj = linear(q, w, b)
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
|
||||||
attn_output = scaled_dot_product_attention(
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
|
||||||
0%| | 0/7 [00:00<?, ?it/s]
|
0%| | 0/7 [00:00<?, ?it/s]
|
||||||
14%|█▍ | 1/7 [00:24<02:24, 24.09s/it]
|
14%|█▍ | 1/7 [00:24<02:24, 24.09s/it]
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
29%|██▊ | 2/7 [00:47<01:58, 23.60s/it]
|
29%|██▊ | 2/7 [00:47<01:58, 23.60s/it]
|
||||||
43%|████▎ | 3/7 [01:10<01:33, 23.48s/it]
|
43%|████▎ | 3/7 [01:10<01:33, 23.48s/it]
|
||||||
57%|█████▋ | 4/7 [01:34<01:10, 23.43s/it]
|
57%|█████▋ | 4/7 [01:34<01:10, 23.43s/it]
|
||||||
@@ -129,6 +57,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 3: interacting with world model ...
|
>>> Step 3: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 4: generating actions ...
|
>>> Step 4: generating actions ...
|
||||||
>>> Step 4: interacting with world model ...
|
>>> Step 4: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 5: generating actions ...
|
>>> Step 5: generating actions ...
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_dual_arm_stackbox/case2/unitree_z1_dual_arm_stackbox_case2.mp4",
|
"gt_video": "unitree_z1_dual_arm_stackbox/case2/unitree_z1_dual_arm_stackbox_case2.mp4",
|
||||||
"pred_video": "unitree_z1_dual_arm_stackbox/case2/output/inference/unitree_z1_dual_arm_stackbox_case2_amd.mp4",
|
"pred_video": "unitree_z1_dual_arm_stackbox/case2/output/inference/15_full_fs4.mp4",
|
||||||
"psnr": 23.878153424077645
|
"psnr": 44.90750363879194
|
||||||
}
|
}
|
||||||
@@ -2,7 +2,7 @@ res_dir="unitree_z1_dual_arm_stackbox/case2"
|
|||||||
dataset="unitree_z1_dual_arm_stackbox"
|
dataset="unitree_z1_dual_arm_stackbox"
|
||||||
|
|
||||||
{
|
{
|
||||||
time CUDA_VISIBLE_DEVICES=6 python3 scripts/evaluation/world_model_interaction.py \
|
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
||||||
--seed 123 \
|
--seed 123 \
|
||||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||||
--config configs/inference/world_model_interaction.yaml \
|
--config configs/inference/world_model_interaction.yaml \
|
||||||
@@ -20,5 +20,6 @@ dataset="unitree_z1_dual_arm_stackbox"
|
|||||||
--n_iter 7 \
|
--n_iter 7 \
|
||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae
|
--perframe_ae \
|
||||||
|
--fast_policy_no_decode
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
@@ -1,34 +1,16 @@
|
|||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
2026-02-19 19:33:31.235859: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
__import__("pkg_resources").declare_namespace(__name__)
|
2026-02-19 19:33:31.283866: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-08 07:35:33.682231: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-19 19:33:31.283908: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-08 07:35:33.685275: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-19 19:33:31.284941: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-08 07:35:33.716682: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-19 19:33:31.291610: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
2026-02-08 07:35:33.716728: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-08 07:35:33.718523: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-19 19:33:32.199716: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
2026-02-08 07:35:33.726756: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
Global seed set to 123
|
||||||
2026-02-08 07:35:33.727105: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
>>> Prepared model loaded.
|
||||||
2026-02-08 07:35:34.356722: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
>>> Diffusion backbone (model.model) converted to FP16.
|
||||||
[rank: 0] Global seed set to 123
|
>>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
>>> Encoders (cond_stage_model, embedder) converted to FP16.
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
|
||||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
state_dict = torch.load(ckpt, map_location="cpu")
|
|
||||||
>>> model checkpoint loaded.
|
|
||||||
>>> Load pre-trained model ...
|
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -46,69 +28,15 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
|
✓ KV fused: 66 attention layers
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/7 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
|
||||||
proj = linear(q, w, b)
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
|
||||||
attn_output = scaled_dot_product_attention(
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
|
||||||
0%| | 0/7 [00:00<?, ?it/s]
|
0%| | 0/7 [00:00<?, ?it/s]
|
||||||
14%|█▍ | 1/7 [00:24<02:24, 24.10s/it]
|
14%|█▍ | 1/7 [00:24<02:24, 24.10s/it]
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
29%|██▊ | 2/7 [00:47<01:58, 23.62s/it]
|
29%|██▊ | 2/7 [00:47<01:58, 23.62s/it]
|
||||||
43%|████▎ | 3/7 [01:10<01:34, 23.51s/it]
|
43%|████▎ | 3/7 [01:10<01:34, 23.51s/it]
|
||||||
57%|█████▋ | 4/7 [01:34<01:10, 23.46s/it]
|
57%|█████▋ | 4/7 [01:34<01:10, 23.46s/it]
|
||||||
@@ -129,6 +57,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 3: interacting with world model ...
|
>>> Step 3: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 4: generating actions ...
|
>>> Step 4: generating actions ...
|
||||||
>>> Step 4: interacting with world model ...
|
>>> Step 4: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 5: generating actions ...
|
>>> Step 5: generating actions ...
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_dual_arm_stackbox/case3/unitree_z1_dual_arm_stackbox_case3.mp4",
|
"gt_video": "unitree_z1_dual_arm_stackbox/case3/unitree_z1_dual_arm_stackbox_case3.mp4",
|
||||||
"pred_video": "unitree_z1_dual_arm_stackbox/case3/output/inference/unitree_z1_dual_arm_stackbox_case3_amd.mp4",
|
"pred_video": "unitree_z1_dual_arm_stackbox/case3/output/inference/25_full_fs4.mp4",
|
||||||
"psnr": 25.400458754751128
|
"psnr": 39.63695040491171
|
||||||
}
|
}
|
||||||
@@ -20,5 +20,6 @@ dataset="unitree_z1_dual_arm_stackbox"
|
|||||||
--n_iter 7 \
|
--n_iter 7 \
|
||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae
|
--perframe_ae \
|
||||||
|
--fast_policy_no_decode
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
@@ -1,34 +1,16 @@
|
|||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
2026-02-19 19:36:32.251051: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
__import__("pkg_resources").declare_namespace(__name__)
|
2026-02-19 19:36:32.298464: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-08 07:38:45.572744: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-19 19:36:32.298506: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-08 07:38:45.576864: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-19 19:36:32.299538: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-08 07:38:45.624825: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-19 19:36:32.306168: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
2026-02-08 07:38:45.624883: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-08 07:38:45.627150: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-19 19:36:33.213503: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
2026-02-08 07:38:45.638316: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
Global seed set to 123
|
||||||
2026-02-08 07:38:45.638803: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
>>> Prepared model loaded.
|
||||||
2026-02-08 07:38:46.426363: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
>>> Diffusion backbone (model.model) converted to FP16.
|
||||||
[rank: 0] Global seed set to 123
|
>>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
>>> Encoders (cond_stage_model, embedder) converted to FP16.
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
|
||||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
state_dict = torch.load(ckpt, map_location="cpu")
|
|
||||||
>>> model checkpoint loaded.
|
|
||||||
>>> Load pre-trained model ...
|
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -46,69 +28,15 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
|
✓ KV fused: 66 attention layers
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/7 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
|
||||||
proj = linear(q, w, b)
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
|
||||||
attn_output = scaled_dot_product_attention(
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
|
||||||
0%| | 0/7 [00:00<?, ?it/s]
|
0%| | 0/7 [00:00<?, ?it/s]
|
||||||
14%|█▍ | 1/7 [00:24<02:24, 24.05s/it]
|
14%|█▍ | 1/7 [00:24<02:24, 24.05s/it]
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
29%|██▊ | 2/7 [00:47<01:57, 23.58s/it]
|
29%|██▊ | 2/7 [00:47<01:57, 23.58s/it]
|
||||||
43%|████▎ | 3/7 [01:10<01:33, 23.45s/it]
|
43%|████▎ | 3/7 [01:10<01:33, 23.45s/it]
|
||||||
57%|█████▋ | 4/7 [01:33<01:10, 23.43s/it]
|
57%|█████▋ | 4/7 [01:33<01:10, 23.43s/it]
|
||||||
@@ -129,6 +57,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 3: interacting with world model ...
|
>>> Step 3: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 4: generating actions ...
|
>>> Step 4: generating actions ...
|
||||||
>>> Step 4: interacting with world model ...
|
>>> Step 4: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 5: generating actions ...
|
>>> Step 5: generating actions ...
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_dual_arm_stackbox/case4/unitree_z1_dual_arm_stackbox_case4.mp4",
|
"gt_video": "unitree_z1_dual_arm_stackbox/case4/unitree_z1_dual_arm_stackbox_case4.mp4",
|
||||||
"pred_video": "unitree_z1_dual_arm_stackbox/case4/output/inference/unitree_z1_dual_arm_stackbox_case4_amd.mp4",
|
"pred_video": "unitree_z1_dual_arm_stackbox/case4/output/inference/35_full_fs4.mp4",
|
||||||
"psnr": 24.098958457373858
|
"psnr": 42.34177660061245
|
||||||
}
|
}
|
||||||
@@ -20,5 +20,6 @@ dataset="unitree_z1_dual_arm_stackbox"
|
|||||||
--n_iter 7 \
|
--n_iter 7 \
|
||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae
|
--perframe_ae \
|
||||||
|
--fast_policy_no_decode
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
@@ -1,34 +1,16 @@
|
|||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
2026-02-19 19:39:32.908698: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
__import__("pkg_resources").declare_namespace(__name__)
|
2026-02-19 19:39:32.956378: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-08 07:51:23.961486: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-19 19:39:32.956417: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-08 07:51:24.200063: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-19 19:39:32.957459: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-08 07:51:24.522299: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-19 19:39:32.964104: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
2026-02-08 07:51:24.522350: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-08 07:51:24.528237: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-19 19:39:33.875854: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
2026-02-08 07:51:24.579400: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
|
||||||
2026-02-08 07:51:24.579644: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
|
||||||
2026-02-08 07:51:25.781311: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
|
||||||
Global seed set to 123
|
Global seed set to 123
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
>>> Prepared model loaded.
|
||||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
>>> Diffusion backbone (model.model) converted to FP16.
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
>>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
>>> Encoders (cond_stage_model, embedder) converted to FP16.
|
||||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
state_dict = torch.load(ckpt, map_location="cpu")
|
|
||||||
>>> model checkpoint loaded.
|
|
||||||
>>> Load pre-trained model ...
|
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -46,69 +28,15 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
|
✓ KV fused: 66 attention layers
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
|
||||||
proj = linear(q, w, b)
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
|
||||||
attn_output = scaled_dot_product_attention(
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s]
|
0%| | 0/11 [00:00<?, ?it/s]
|
||||||
9%|▉ | 1/11 [00:24<04:01, 24.10s/it]
|
9%|▉ | 1/11 [00:24<04:01, 24.10s/it]
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
18%|█▊ | 2/11 [00:47<03:32, 23.61s/it]
|
18%|█▊ | 2/11 [00:47<03:32, 23.61s/it]
|
||||||
27%|██▋ | 3/11 [01:10<03:08, 23.50s/it]
|
27%|██▋ | 3/11 [01:10<03:08, 23.50s/it]
|
||||||
36%|███▋ | 4/11 [01:34<02:44, 23.45s/it]
|
36%|███▋ | 4/11 [01:34<02:44, 23.45s/it]
|
||||||
@@ -141,6 +69,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 6: generating actions ...
|
>>> Step 6: generating actions ...
|
||||||
>>> Step 6: interacting with world model ...
|
>>> Step 6: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 7: generating actions ...
|
>>> Step 7: generating actions ...
|
||||||
>>> Step 7: interacting with world model ...
|
>>> Step 7: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case1/unitree_z1_dual_arm_stackbox_v2_case1.mp4",
|
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case1/unitree_z1_dual_arm_stackbox_v2_case1.mp4",
|
||||||
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case1/output/inference/unitree_z1_dual_arm_stackbox_v2_case1_amd.mp4",
|
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case1/output/inference/5_full_fs4.mp4",
|
||||||
"psnr": 18.126776535969576
|
"psnr": 26.68301835085306
|
||||||
}
|
}
|
||||||
@@ -2,7 +2,7 @@ res_dir="unitree_z1_dual_arm_stackbox_v2/case1"
|
|||||||
dataset="unitree_z1_dual_arm_stackbox_v2"
|
dataset="unitree_z1_dual_arm_stackbox_v2"
|
||||||
|
|
||||||
{
|
{
|
||||||
time CUDA_VISIBLE_DEVICES=7 python3 scripts/evaluation/world_model_interaction.py \
|
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
||||||
--seed 123 \
|
--seed 123 \
|
||||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||||
--config configs/inference/world_model_interaction.yaml \
|
--config configs/inference/world_model_interaction.yaml \
|
||||||
@@ -20,5 +20,6 @@ dataset="unitree_z1_dual_arm_stackbox_v2"
|
|||||||
--n_iter 11 \
|
--n_iter 11 \
|
||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae
|
--perframe_ae \
|
||||||
|
--fast_policy_no_decode
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
@@ -1,34 +1,16 @@
|
|||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
2026-02-19 19:44:07.724109: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
__import__("pkg_resources").declare_namespace(__name__)
|
2026-02-19 19:44:07.771461: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-08 07:56:31.144789: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-19 19:44:07.771505: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-08 07:56:31.148256: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-19 19:44:07.772537: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-08 07:56:31.178870: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-19 19:44:07.779172: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
2026-02-08 07:56:31.178898: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-08 07:56:31.180683: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-19 19:44:08.688975: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
2026-02-08 07:56:31.188800: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
Global seed set to 123
|
||||||
2026-02-08 07:56:31.189142: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
>>> Prepared model loaded.
|
||||||
2026-02-08 07:56:31.810098: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
>>> Diffusion backbone (model.model) converted to FP16.
|
||||||
[rank: 0] Global seed set to 123
|
>>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
>>> Encoders (cond_stage_model, embedder) converted to FP16.
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
|
||||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
state_dict = torch.load(ckpt, map_location="cpu")
|
|
||||||
>>> model checkpoint loaded.
|
|
||||||
>>> Load pre-trained model ...
|
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -46,69 +28,15 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
|
✓ KV fused: 66 attention layers
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
|
||||||
proj = linear(q, w, b)
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
|
||||||
attn_output = scaled_dot_product_attention(
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s]
|
0%| | 0/11 [00:00<?, ?it/s]
|
||||||
9%|▉ | 1/11 [00:24<04:00, 24.03s/it]
|
9%|▉ | 1/11 [00:24<04:00, 24.03s/it]
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
18%|█▊ | 2/11 [00:47<03:31, 23.54s/it]
|
18%|█▊ | 2/11 [00:47<03:31, 23.54s/it]
|
||||||
27%|██▋ | 3/11 [01:10<03:07, 23.42s/it]
|
27%|██▋ | 3/11 [01:10<03:07, 23.42s/it]
|
||||||
36%|███▋ | 4/11 [01:33<02:43, 23.40s/it]
|
36%|███▋ | 4/11 [01:33<02:43, 23.40s/it]
|
||||||
@@ -141,6 +69,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 6: generating actions ...
|
>>> Step 6: generating actions ...
|
||||||
>>> Step 6: interacting with world model ...
|
>>> Step 6: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 7: generating actions ...
|
>>> Step 7: generating actions ...
|
||||||
>>> Step 7: interacting with world model ...
|
>>> Step 7: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case2/unitree_z1_dual_arm_stackbox_v2_case2.mp4",
|
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case2/unitree_z1_dual_arm_stackbox_v2_case2.mp4",
|
||||||
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case2/output/inference/unitree_z1_dual_arm_stackbox_v2_case2_amd.mp4",
|
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case2/output/inference/15_full_fs4.mp4",
|
||||||
"psnr": 19.38130614773096
|
"psnr": 27.46347145461597
|
||||||
}
|
}
|
||||||
@@ -20,5 +20,6 @@ dataset="unitree_z1_dual_arm_stackbox_v2"
|
|||||||
--n_iter 11 \
|
--n_iter 11 \
|
||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae
|
--perframe_ae \
|
||||||
|
--fast_policy_no_decode
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
@@ -1,34 +1,16 @@
|
|||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
2026-02-19 19:48:42.460586: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
__import__("pkg_resources").declare_namespace(__name__)
|
2026-02-19 19:48:42.508096: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-08 07:56:04.467082: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-19 19:48:42.508140: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-08 07:56:04.470145: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-19 19:48:42.509152: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-08 07:56:04.502248: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-19 19:48:42.515865: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
2026-02-08 07:56:04.502277: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-08 07:56:04.504088: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-19 19:48:43.425699: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
2026-02-08 07:56:04.512557: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
Global seed set to 123
|
||||||
2026-02-08 07:56:04.512830: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
>>> Prepared model loaded.
|
||||||
2026-02-08 07:56:05.259641: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
>>> Diffusion backbone (model.model) converted to FP16.
|
||||||
[rank: 0] Global seed set to 123
|
>>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
>>> Encoders (cond_stage_model, embedder) converted to FP16.
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
|
||||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
state_dict = torch.load(ckpt, map_location="cpu")
|
|
||||||
>>> model checkpoint loaded.
|
|
||||||
>>> Load pre-trained model ...
|
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -46,69 +28,15 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
|
✓ KV fused: 66 attention layers
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
|
||||||
proj = linear(q, w, b)
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
|
||||||
attn_output = scaled_dot_product_attention(
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s]
|
0%| | 0/11 [00:00<?, ?it/s]
|
||||||
9%|▉ | 1/11 [00:24<04:00, 24.07s/it]
|
9%|▉ | 1/11 [00:24<04:00, 24.07s/it]
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
18%|█▊ | 2/11 [00:47<03:32, 23.62s/it]
|
18%|█▊ | 2/11 [00:47<03:32, 23.62s/it]
|
||||||
27%|██▋ | 3/11 [01:10<03:08, 23.51s/it]
|
27%|██▋ | 3/11 [01:10<03:08, 23.51s/it]
|
||||||
36%|███▋ | 4/11 [01:34<02:44, 23.46s/it]
|
36%|███▋ | 4/11 [01:34<02:44, 23.46s/it]
|
||||||
@@ -141,6 +69,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 6: generating actions ...
|
>>> Step 6: generating actions ...
|
||||||
>>> Step 6: interacting with world model ...
|
>>> Step 6: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 7: generating actions ...
|
>>> Step 7: generating actions ...
|
||||||
>>> Step 7: interacting with world model ...
|
>>> Step 7: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case3/unitree_z1_dual_arm_stackbox_v2_case3.mp4",
|
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case3/unitree_z1_dual_arm_stackbox_v2_case3.mp4",
|
||||||
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case3/output/inference/unitree_z1_dual_arm_stackbox_v2_case3_amd.mp4",
|
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case3/output/inference/25_full_fs4.mp4",
|
||||||
"psnr": 18.74462122425683
|
"psnr": 28.604047286947512
|
||||||
}
|
}
|
||||||
@@ -20,5 +20,6 @@ dataset="unitree_z1_dual_arm_stackbox_v2"
|
|||||||
--n_iter 11 \
|
--n_iter 11 \
|
||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae
|
--perframe_ae \
|
||||||
|
--fast_policy_no_decode
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
@@ -1,34 +1,16 @@
|
|||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
2026-02-19 19:53:17.574354: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
__import__("pkg_resources").declare_namespace(__name__)
|
2026-02-19 19:53:17.621335: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-08 08:04:16.104516: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-19 19:53:17.621388: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-08 08:04:16.109112: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-19 19:53:17.622415: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-08 08:04:16.138703: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-19 19:53:17.629050: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
2026-02-08 08:04:16.138737: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-08 08:04:16.140302: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-19 19:53:18.537233: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
2026-02-08 08:04:16.147672: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
|
||||||
2026-02-08 08:04:16.147903: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
|
||||||
2026-02-08 08:04:17.363218: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
|
||||||
Global seed set to 123
|
Global seed set to 123
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
>>> Prepared model loaded.
|
||||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
>>> Diffusion backbone (model.model) converted to FP16.
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
>>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
>>> Encoders (cond_stage_model, embedder) converted to FP16.
|
||||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
state_dict = torch.load(ckpt, map_location="cpu")
|
|
||||||
>>> model checkpoint loaded.
|
|
||||||
>>> Load pre-trained model ...
|
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -46,69 +28,15 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
|
✓ KV fused: 66 attention layers
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
|
||||||
proj = linear(q, w, b)
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
|
||||||
attn_output = scaled_dot_product_attention(
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s]
|
0%| | 0/11 [00:00<?, ?it/s]
|
||||||
9%|▉ | 1/11 [00:24<04:00, 24.09s/it]
|
9%|▉ | 1/11 [00:24<04:00, 24.09s/it]
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
18%|█▊ | 2/11 [00:47<03:32, 23.62s/it]
|
18%|█▊ | 2/11 [00:47<03:32, 23.62s/it]
|
||||||
27%|██▋ | 3/11 [01:10<03:07, 23.49s/it]
|
27%|██▋ | 3/11 [01:10<03:07, 23.49s/it]
|
||||||
36%|███▋ | 4/11 [01:34<02:44, 23.47s/it]
|
36%|███▋ | 4/11 [01:34<02:44, 23.47s/it]
|
||||||
@@ -141,6 +69,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 6: generating actions ...
|
>>> Step 6: generating actions ...
|
||||||
>>> Step 6: interacting with world model ...
|
>>> Step 6: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 7: generating actions ...
|
>>> Step 7: generating actions ...
|
||||||
>>> Step 7: interacting with world model ...
|
>>> Step 7: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case4/unitree_z1_dual_arm_stackbox_v2_case4.mp4",
|
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case4/unitree_z1_dual_arm_stackbox_v2_case4.mp4",
|
||||||
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case4/output/inference/unitree_z1_dual_arm_stackbox_v2_case4_amd.mp4",
|
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case4/output/inference/35_full_fs4.mp4",
|
||||||
"psnr": 19.526448380726254
|
"psnr": 25.578757174083307
|
||||||
}
|
}
|
||||||
@@ -2,7 +2,7 @@ res_dir="unitree_z1_dual_arm_stackbox_v2/case4"
|
|||||||
dataset="unitree_z1_dual_arm_stackbox_v2"
|
dataset="unitree_z1_dual_arm_stackbox_v2"
|
||||||
|
|
||||||
{
|
{
|
||||||
time CUDA_VISIBLE_DEVICES=6 python3 scripts/evaluation/world_model_interaction.py \
|
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
||||||
--seed 123 \
|
--seed 123 \
|
||||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||||
--config configs/inference/world_model_interaction.yaml \
|
--config configs/inference/world_model_interaction.yaml \
|
||||||
@@ -20,5 +20,6 @@ dataset="unitree_z1_dual_arm_stackbox_v2"
|
|||||||
--n_iter 11 \
|
--n_iter 11 \
|
||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae
|
--perframe_ae \
|
||||||
|
--fast_policy_no_decode
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
@@ -1,34 +1,16 @@
|
|||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
2026-02-19 19:57:52.488339: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
__import__("pkg_resources").declare_namespace(__name__)
|
2026-02-19 19:57:52.536176: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-08 08:12:47.424053: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-19 19:57:52.536222: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-08 08:12:47.427280: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-19 19:57:52.537285: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-08 08:12:47.458253: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-19 19:57:52.544051: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
2026-02-08 08:12:47.458288: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-08 08:12:47.462758: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-19 19:57:53.469912: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
2026-02-08 08:12:47.518283: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
|
||||||
2026-02-08 08:12:47.518566: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
|
||||||
2026-02-08 08:12:48.593011: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
|
||||||
Global seed set to 123
|
Global seed set to 123
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
>>> Prepared model loaded.
|
||||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
>>> Diffusion backbone (model.model) converted to FP16.
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
>>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
>>> Encoders (cond_stage_model, embedder) converted to FP16.
|
||||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
state_dict = torch.load(ckpt, map_location="cpu")
|
|
||||||
>>> model checkpoint loaded.
|
|
||||||
>>> Load pre-trained model ...
|
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -46,69 +28,15 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
|
✓ KV fused: 66 attention layers
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/12 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
|
||||||
proj = linear(q, w, b)
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
|
||||||
attn_output = scaled_dot_product_attention(
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
|
||||||
0%| | 0/12 [00:00<?, ?it/s]
|
0%| | 0/12 [00:00<?, ?it/s]
|
||||||
8%|▊ | 1/12 [00:24<04:24, 24.06s/it]
|
8%|▊ | 1/12 [00:24<04:24, 24.06s/it]
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
17%|█▋ | 2/12 [00:47<03:55, 23.56s/it]
|
17%|█▋ | 2/12 [00:47<03:55, 23.56s/it]
|
||||||
25%|██▌ | 3/12 [01:10<03:31, 23.46s/it]
|
25%|██▌ | 3/12 [01:10<03:31, 23.46s/it]
|
||||||
33%|███▎ | 4/12 [01:33<03:07, 23.43s/it]
|
33%|███▎ | 4/12 [01:33<03:07, 23.43s/it]
|
||||||
@@ -144,6 +72,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 7: generating actions ...
|
>>> Step 7: generating actions ...
|
||||||
>>> Step 7: interacting with world model ...
|
>>> Step 7: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 8: generating actions ...
|
>>> Step 8: generating actions ...
|
||||||
>>> Step 8: interacting with world model ...
|
>>> Step 8: interacting with world model ...
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_stackbox/case1/unitree_z1_stackbox_case1.mp4",
|
"gt_video": "unitree_z1_stackbox/case1/unitree_z1_stackbox_case1.mp4",
|
||||||
"pred_video": "unitree_z1_stackbox/case1/output/inference/unitree_z1_stackbox_case1_amd.mp4",
|
"pred_video": "unitree_z1_stackbox/case1/output/inference/5_full_fs4.mp4",
|
||||||
"psnr": 19.81391789862606
|
"psnr": 46.05271283048069
|
||||||
}
|
}
|
||||||
@@ -2,7 +2,7 @@ res_dir="unitree_z1_stackbox/case1"
|
|||||||
dataset="unitree_z1_stackbox"
|
dataset="unitree_z1_stackbox"
|
||||||
|
|
||||||
{
|
{
|
||||||
time CUDA_VISIBLE_DEVICES=5 python3 scripts/evaluation/world_model_interaction.py \
|
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
||||||
--seed 123 \
|
--seed 123 \
|
||||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||||
--config configs/inference/world_model_interaction.yaml \
|
--config configs/inference/world_model_interaction.yaml \
|
||||||
@@ -20,5 +20,6 @@ dataset="unitree_z1_stackbox"
|
|||||||
--n_iter 12 \
|
--n_iter 12 \
|
||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae
|
--perframe_ae \
|
||||||
|
--fast_policy_no_decode
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
@@ -1,34 +1,16 @@
|
|||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
2026-02-19 20:02:50.975402: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
__import__("pkg_resources").declare_namespace(__name__)
|
2026-02-19 20:02:51.023211: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-08 08:15:49.934949: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-19 20:02:51.023253: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-08 08:15:49.937974: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-19 20:02:51.024328: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-08 08:15:49.969069: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-19 20:02:51.031176: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
2026-02-08 08:15:49.969100: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-08 08:15:49.970909: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-19 20:02:51.947400: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
2026-02-08 08:15:49.979005: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
Global seed set to 123
|
||||||
2026-02-08 08:15:49.979255: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
>>> Prepared model loaded.
|
||||||
2026-02-08 08:15:50.597743: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
>>> Diffusion backbone (model.model) converted to FP16.
|
||||||
[rank: 0] Global seed set to 123
|
>>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
>>> Encoders (cond_stage_model, embedder) converted to FP16.
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
|
||||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
state_dict = torch.load(ckpt, map_location="cpu")
|
|
||||||
>>> model checkpoint loaded.
|
|
||||||
>>> Load pre-trained model ...
|
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -46,69 +28,15 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
|
✓ KV fused: 66 attention layers
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/12 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
|
||||||
proj = linear(q, w, b)
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
|
||||||
attn_output = scaled_dot_product_attention(
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
|
||||||
0%| | 0/12 [00:00<?, ?it/s]
|
0%| | 0/12 [00:00<?, ?it/s]
|
||||||
8%|▊ | 1/12 [00:24<04:24, 24.08s/it]
|
8%|▊ | 1/12 [00:24<04:24, 24.08s/it]
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
17%|█▋ | 2/12 [00:47<03:56, 23.62s/it]
|
17%|█▋ | 2/12 [00:47<03:56, 23.62s/it]
|
||||||
25%|██▌ | 3/12 [01:10<03:31, 23.51s/it]
|
25%|██▌ | 3/12 [01:10<03:31, 23.51s/it]
|
||||||
33%|███▎ | 4/12 [01:34<03:07, 23.48s/it]
|
33%|███▎ | 4/12 [01:34<03:07, 23.48s/it]
|
||||||
@@ -144,6 +72,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 7: generating actions ...
|
>>> Step 7: generating actions ...
|
||||||
>>> Step 7: interacting with world model ...
|
>>> Step 7: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 8: generating actions ...
|
>>> Step 8: generating actions ...
|
||||||
>>> Step 8: interacting with world model ...
|
>>> Step 8: interacting with world model ...
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_stackbox/case2/unitree_z1_stackbox_case2.mp4",
|
"gt_video": "unitree_z1_stackbox/case2/unitree_z1_stackbox_case2.mp4",
|
||||||
"pred_video": "unitree_z1_stackbox/case2/output/inference/unitree_z1_stackbox_case2_amd.mp4",
|
"pred_video": "unitree_z1_stackbox/case2/output/inference/15_full_fs4.mp4",
|
||||||
"psnr": 21.083821459054743
|
"psnr": 43.005233352958804
|
||||||
}
|
}
|
||||||
@@ -20,5 +20,6 @@ dataset="unitree_z1_stackbox"
|
|||||||
--n_iter 12 \
|
--n_iter 12 \
|
||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae
|
--perframe_ae \
|
||||||
|
--fast_policy_no_decode
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
@@ -1,34 +1,16 @@
|
|||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
2026-02-19 20:07:49.410622: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
__import__("pkg_resources").declare_namespace(__name__)
|
2026-02-19 20:07:49.457896: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-08 08:16:22.299521: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-19 20:07:49.457948: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-08 08:16:22.302545: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-19 20:07:49.458967: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-08 08:16:22.335354: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-19 20:07:49.465632: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
2026-02-08 08:16:22.335389: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-08 08:16:22.337179: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-19 20:07:50.373326: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
2026-02-08 08:16:22.345296: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
Global seed set to 123
|
||||||
2026-02-08 08:16:22.345548: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
>>> Prepared model loaded.
|
||||||
2026-02-08 08:16:23.008743: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
>>> Diffusion backbone (model.model) converted to FP16.
|
||||||
[rank: 0] Global seed set to 123
|
>>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
>>> Encoders (cond_stage_model, embedder) converted to FP16.
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
|
||||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
state_dict = torch.load(ckpt, map_location="cpu")
|
|
||||||
>>> model checkpoint loaded.
|
|
||||||
>>> Load pre-trained model ...
|
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -46,69 +28,15 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
|
✓ KV fused: 66 attention layers
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/12 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
|
||||||
proj = linear(q, w, b)
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
|
||||||
attn_output = scaled_dot_product_attention(
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
|
||||||
0%| | 0/12 [00:00<?, ?it/s]
|
0%| | 0/12 [00:00<?, ?it/s]
|
||||||
8%|▊ | 1/12 [00:24<04:25, 24.17s/it]
|
8%|▊ | 1/12 [00:24<04:25, 24.17s/it]
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
17%|█▋ | 2/12 [00:47<03:56, 23.64s/it]
|
17%|█▋ | 2/12 [00:47<03:56, 23.64s/it]
|
||||||
25%|██▌ | 3/12 [01:10<03:31, 23.53s/it]
|
25%|██▌ | 3/12 [01:10<03:31, 23.53s/it]
|
||||||
33%|███▎ | 4/12 [01:34<03:07, 23.49s/it]
|
33%|███▎ | 4/12 [01:34<03:07, 23.49s/it]
|
||||||
@@ -144,6 +72,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 7: generating actions ...
|
>>> Step 7: generating actions ...
|
||||||
>>> Step 7: interacting with world model ...
|
>>> Step 7: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 8: generating actions ...
|
>>> Step 8: generating actions ...
|
||||||
>>> Step 8: interacting with world model ...
|
>>> Step 8: interacting with world model ...
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_stackbox/case3/unitree_z1_stackbox_case3.mp4",
|
"gt_video": "unitree_z1_stackbox/case3/unitree_z1_stackbox_case3.mp4",
|
||||||
"pred_video": "unitree_z1_stackbox/case3/output/inference/unitree_z1_stackbox_case3_amd.mp4",
|
"pred_video": "unitree_z1_stackbox/case3/output/inference/25_full_fs4.mp4",
|
||||||
"psnr": 21.322784880212172
|
"psnr": 49.489774674892764
|
||||||
}
|
}
|
||||||
@@ -20,5 +20,6 @@ dataset="unitree_z1_stackbox"
|
|||||||
--n_iter 12 \
|
--n_iter 12 \
|
||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae
|
--perframe_ae \
|
||||||
|
--fast_policy_no_decode
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
@@ -1,34 +1,16 @@
|
|||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
2026-02-19 20:12:48.029611: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
__import__("pkg_resources").declare_namespace(__name__)
|
2026-02-19 20:12:48.076914: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-08 08:25:54.657305: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-19 20:12:48.076957: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-08 08:25:54.660628: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-19 20:12:48.077981: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-08 08:25:54.691237: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-19 20:12:48.084620: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
2026-02-08 08:25:54.691275: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-08 08:25:54.693046: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-19 20:12:49.004753: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
2026-02-08 08:25:54.701142: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
|
||||||
2026-02-08 08:25:54.701413: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
|
||||||
2026-02-08 08:25:55.801367: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
|
||||||
Global seed set to 123
|
Global seed set to 123
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
>>> Prepared model loaded.
|
||||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
>>> Diffusion backbone (model.model) converted to FP16.
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
>>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
>>> Encoders (cond_stage_model, embedder) converted to FP16.
|
||||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
|
||||||
state_dict = torch.load(ckpt, map_location="cpu")
|
|
||||||
>>> model checkpoint loaded.
|
|
||||||
>>> Load pre-trained model ...
|
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -46,69 +28,15 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
|
✓ KV fused: 66 attention layers
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/12 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
|
||||||
proj = linear(q, w, b)
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
|
||||||
attn_output = scaled_dot_product_attention(
|
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
|
||||||
0%| | 0/12 [00:00<?, ?it/s]
|
0%| | 0/12 [00:00<?, ?it/s]
|
||||||
8%|▊ | 1/12 [00:24<04:24, 24.06s/it]
|
8%|▊ | 1/12 [00:24<04:24, 24.06s/it]
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
17%|█▋ | 2/12 [00:47<03:55, 23.59s/it]
|
17%|█▋ | 2/12 [00:47<03:55, 23.59s/it]
|
||||||
25%|██▌ | 3/12 [01:10<03:31, 23.49s/it]
|
25%|██▌ | 3/12 [01:10<03:31, 23.49s/it]
|
||||||
33%|███▎ | 4/12 [01:34<03:07, 23.44s/it]
|
33%|███▎ | 4/12 [01:34<03:07, 23.44s/it]
|
||||||
@@ -144,6 +72,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 7: generating actions ...
|
>>> Step 7: generating actions ...
|
||||||
>>> Step 7: interacting with world model ...
|
>>> Step 7: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 8: generating actions ...
|
>>> Step 8: generating actions ...
|
||||||
>>> Step 8: interacting with world model ...
|
>>> Step 8: interacting with world model ...
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_stackbox/case4/unitree_z1_stackbox_case4.mp4",
|
"gt_video": "unitree_z1_stackbox/case4/unitree_z1_stackbox_case4.mp4",
|
||||||
"pred_video": "unitree_z1_stackbox/case4/output/inference/unitree_z1_stackbox_case4_amd.mp4",
|
"pred_video": "unitree_z1_stackbox/case4/output/inference/35_full_fs4.mp4",
|
||||||
"psnr": 25.32928948331741
|
"psnr": 47.18724378194084
|
||||||
}
|
}
|
||||||
@@ -2,7 +2,7 @@ res_dir="unitree_z1_stackbox/case4"
|
|||||||
dataset="unitree_z1_stackbox"
|
dataset="unitree_z1_stackbox"
|
||||||
|
|
||||||
{
|
{
|
||||||
time CUDA_VISIBLE_DEVICES=7 python3 scripts/evaluation/world_model_interaction.py \
|
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
||||||
--seed 123 \
|
--seed 123 \
|
||||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||||
--config configs/inference/world_model_interaction.yaml \
|
--config configs/inference/world_model_interaction.yaml \
|
||||||
@@ -20,5 +20,6 @@ dataset="unitree_z1_stackbox"
|
|||||||
--n_iter 12 \
|
--n_iter 12 \
|
||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae
|
--perframe_ae \
|
||||||
|
--fast_policy_no_decode
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
Reference in New Issue
Block a user