- state_unet 放到一个独立的 CUDA stream 上执行
- action_unet 在默认 stream 上同时执行 - 用 wait_stream 确保两者都完成后再返回 两个 1D UNet 输入完全独立,共享的 hs_a 和 context_action 都是只读的。GPU 利用率只有 ~31%,小张量 kernel 不会打满 GPU,两个 stream 可以真正并行。
This commit is contained in:
@@ -5,7 +5,11 @@
|
||||
"Bash(mamba env:*)",
|
||||
"Bash(micromamba env list:*)",
|
||||
"Bash(echo:*)",
|
||||
"Bash(git show:*)"
|
||||
"Bash(git show:*)",
|
||||
"Bash(nvidia-smi:*)",
|
||||
"Bash(conda activate unifolm-wma)",
|
||||
"Bash(conda info:*)",
|
||||
"Bash(direnv allow:*)"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
2
.envrc
Normal file
2
.envrc
Normal file
@@ -0,0 +1,2 @@
|
||||
eval "$(conda shell.bash hook 2>/dev/null)"
|
||||
conda activate unifolm-wma
|
||||
@@ -688,6 +688,8 @@ class WMAModel(nn.Module):
|
||||
# Context precomputation cache
|
||||
self._ctx_cache_enabled = False
|
||||
self._ctx_cache = {}
|
||||
# Reusable CUDA stream for parallel state_unet / action_unet
|
||||
self._state_stream = torch.cuda.Stream()
|
||||
|
||||
def forward(self,
|
||||
x: Tensor,
|
||||
@@ -842,15 +844,16 @@ class WMAModel(nn.Module):
|
||||
|
||||
if not self.base_model_gen_only:
|
||||
ba, _, _ = x_action.shape
|
||||
ts_state = timesteps[:ba] if b > 1 else timesteps
|
||||
# Run action_unet and state_unet in parallel via CUDA streams
|
||||
s_stream = self._state_stream
|
||||
s_stream.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(s_stream):
|
||||
s_y = self.state_unet(x_state, ts_state, hs_a,
|
||||
context_action[:2], **kwargs)
|
||||
a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
|
||||
context_action[:2], **kwargs)
|
||||
# Predict state
|
||||
if b > 1:
|
||||
s_y = self.state_unet(x_state, timesteps[:ba], hs_a,
|
||||
context_action[:2], **kwargs)
|
||||
else:
|
||||
s_y = self.state_unet(x_state, timesteps, hs_a,
|
||||
context_action[:2], **kwargs)
|
||||
torch.cuda.current_stream().wait_stream(s_stream)
|
||||
else:
|
||||
a_y = torch.zeros_like(x_action)
|
||||
s_y = torch.zeros_like(x_state)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
2026-02-10 19:43:34.679819: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-10 19:43:34.729245: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-10 19:43:34.729298: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-10 19:43:34.730600: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-10 19:43:34.738078: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
2026-02-10 21:29:54.531726: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-10 21:29:54.581091: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-10 21:29:54.581133: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-10 21:29:54.582445: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-10 21:29:54.589984: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-10 19:43:35.659490: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
2026-02-10 21:29:55.504855: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
Global seed set to 123
|
||||
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||
>>> Prepared model loaded.
|
||||
@@ -30,7 +30,7 @@ DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
|
||||
0%| | 0/11 [00:00<?, ?it/s]
|
||||
9%|▉ | 1/11 [00:33<05:36, 33.61s/it]>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
@@ -83,7 +83,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
18%|█▊ | 2/11 [01:07<05:04, 33.81s/it]
|
||||
@@ -114,6 +114,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
>>> Step 6: generating actions ...
|
||||
>>> Step 6: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 7: generating actions ...
|
||||
>>> Step 7: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 7: generating actions ...
|
||||
>>> Step 7: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
|
||||
Reference in New Issue
Block a user