- state_unet 放到一个独立的 CUDA stream 上执行
- action_unet 在默认 stream 上同时执行 - 用 wait_stream 确保两者都完成后再返回 两个 1D UNet 输入完全独立,共享的 hs_a 和 context_action 都是只读的。GPU 利用率只有 ~31%,小张量 kernel 不会打满 GPU,两个 stream 可以真正并行。
This commit is contained in:
@@ -5,7 +5,11 @@
|
|||||||
"Bash(mamba env:*)",
|
"Bash(mamba env:*)",
|
||||||
"Bash(micromamba env list:*)",
|
"Bash(micromamba env list:*)",
|
||||||
"Bash(echo:*)",
|
"Bash(echo:*)",
|
||||||
"Bash(git show:*)"
|
"Bash(git show:*)",
|
||||||
|
"Bash(nvidia-smi:*)",
|
||||||
|
"Bash(conda activate unifolm-wma)",
|
||||||
|
"Bash(conda info:*)",
|
||||||
|
"Bash(direnv allow:*)"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
2
.envrc
Normal file
2
.envrc
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
eval "$(conda shell.bash hook 2>/dev/null)"
|
||||||
|
conda activate unifolm-wma
|
||||||
@@ -688,6 +688,8 @@ class WMAModel(nn.Module):
|
|||||||
# Context precomputation cache
|
# Context precomputation cache
|
||||||
self._ctx_cache_enabled = False
|
self._ctx_cache_enabled = False
|
||||||
self._ctx_cache = {}
|
self._ctx_cache = {}
|
||||||
|
# Reusable CUDA stream for parallel state_unet / action_unet
|
||||||
|
self._state_stream = torch.cuda.Stream()
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
@@ -842,15 +844,16 @@ class WMAModel(nn.Module):
|
|||||||
|
|
||||||
if not self.base_model_gen_only:
|
if not self.base_model_gen_only:
|
||||||
ba, _, _ = x_action.shape
|
ba, _, _ = x_action.shape
|
||||||
|
ts_state = timesteps[:ba] if b > 1 else timesteps
|
||||||
|
# Run action_unet and state_unet in parallel via CUDA streams
|
||||||
|
s_stream = self._state_stream
|
||||||
|
s_stream.wait_stream(torch.cuda.current_stream())
|
||||||
|
with torch.cuda.stream(s_stream):
|
||||||
|
s_y = self.state_unet(x_state, ts_state, hs_a,
|
||||||
|
context_action[:2], **kwargs)
|
||||||
a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
|
a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
|
||||||
context_action[:2], **kwargs)
|
context_action[:2], **kwargs)
|
||||||
# Predict state
|
torch.cuda.current_stream().wait_stream(s_stream)
|
||||||
if b > 1:
|
|
||||||
s_y = self.state_unet(x_state, timesteps[:ba], hs_a,
|
|
||||||
context_action[:2], **kwargs)
|
|
||||||
else:
|
|
||||||
s_y = self.state_unet(x_state, timesteps, hs_a,
|
|
||||||
context_action[:2], **kwargs)
|
|
||||||
else:
|
else:
|
||||||
a_y = torch.zeros_like(x_action)
|
a_y = torch.zeros_like(x_action)
|
||||||
s_y = torch.zeros_like(x_state)
|
s_y = torch.zeros_like(x_state)
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
2026-02-10 19:43:34.679819: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-10 21:29:54.531726: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
2026-02-10 19:43:34.729245: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-10 21:29:54.581091: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-10 19:43:34.729298: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
2026-02-10 21:29:54.581133: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-10 19:43:34.730600: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-10 21:29:54.582445: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-10 19:43:34.738078: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
2026-02-10 21:29:54.589984: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-10 19:43:35.659490: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
2026-02-10 21:29:55.504855: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
Global seed set to 123
|
Global seed set to 123
|
||||||
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||||
>>> Prepared model loaded.
|
>>> Prepared model loaded.
|
||||||
@@ -30,7 +30,7 @@ DEBUG:h5py._conv:Creating converter from 3 to 5
|
|||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s]
|
0%| | 0/11 [00:00<?, ?it/s]
|
||||||
9%|▉ | 1/11 [00:33<05:36, 33.61s/it]>>> Step 0: generating actions ...
|
9%|▉ | 1/11 [00:33<05:36, 33.61s/it]>>> Step 0: generating actions ...
|
||||||
>>> Step 0: interacting with world model ...
|
>>> Step 0: interacting with world model ...
|
||||||
@@ -83,7 +83,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
|
|||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
18%|█▊ | 2/11 [01:07<05:04, 33.81s/it]
|
18%|█▊ | 2/11 [01:07<05:04, 33.81s/it]
|
||||||
@@ -114,6 +114,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 6: generating actions ...
|
>>> Step 6: generating actions ...
|
||||||
>>> Step 6: interacting with world model ...
|
>>> Step 6: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 7: generating actions ...
|
>>> Step 7: generating actions ...
|
||||||
>>> Step 7: interacting with world model ...
|
>>> Step 7: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
Reference in New Issue
Block a user