1. einsum('b i d, b j d -> b i j') → torch.bmm(q, k.transpose(-1,-2)) — 直接映射 rocBLAS batched GEMM

2. baddbmm 把 scale 融合进 GEMM,少一次 kernel launch
3. 第二个 einsum 同理换torch.bm
每一轮加速1到两秒
This commit is contained in:
2026-02-08 18:54:48 +00:00
parent 7338cc384a
commit a2cd34dd51
5 changed files with 994 additions and 37 deletions

View File

@@ -1,14 +1,16 @@
2026-02-08 05:06:45.806187: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 05:06:45.809295: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 05:06:45.840950: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 05:06:45.840981: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 05:06:45.842814: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 05:06:45.851049: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 05:06:45.851316: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 18:28:48.960238: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 18:28:48.963331: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 18:28:48.995688: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 18:28:48.995732: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 18:28:48.997547: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 18:28:49.005673: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 18:28:49.005948: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 05:06:47.225477: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
2026-02-08 18:28:50.009660: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
@@ -18,15 +20,27 @@ INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:198: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
>>> Applying precision settings:
- Diffusion dtype: bf16
- Projector mode: bf16_full
- Encoder mode: bf16_full
- VAE dtype: fp32
✓ Diffusion model weights converted to bfloat16
✓ Projectors converted to bfloat16
✓ Encoders converted to bfloat16
✓ VAE kept in fp32 for best quality
⚠ Found 849 fp32 params, converting to bf16
✓ All parameters converted to bfloat16
✓ torch.compile: 3 ResBlocks in output_blocks[5, 8, 9]
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
@@ -49,11 +63,11 @@ DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
@@ -106,7 +120,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
9%|▉ | 1/11 [01:14<12:29, 74.95s/it]
18%|█▊ | 2/11 [02:23<10:40, 71.18s/it]
@@ -139,6 +153,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>