From dcbcb2c377db625da74d149709adb27297146e31 Mon Sep 17 00:00:00 2001 From: qhy <2728290997@qq.com> Date: Tue, 10 Feb 2026 21:41:48 +0800 Subject: [PATCH] =?UTF-8?q?-=20state=5Funet=20=E6=94=BE=E5=88=B0=E4=B8=80?= =?UTF-8?q?=E4=B8=AA=E7=8B=AC=E7=AB=8B=E7=9A=84=20CUDA=20stream=20?= =?UTF-8?q?=E4=B8=8A=E6=89=A7=E8=A1=8C=20=20=20-=20action=5Funet=20?= =?UTF-8?q?=E5=9C=A8=E9=BB=98=E8=AE=A4=20stream=20=E4=B8=8A=E5=90=8C?= =?UTF-8?q?=E6=97=B6=E6=89=A7=E8=A1=8C=20=20=20-=20=E7=94=A8=20wait=5Fstre?= =?UTF-8?q?am=20=E7=A1=AE=E4=BF=9D=E4=B8=A4=E8=80=85=E9=83=BD=E5=AE=8C?= =?UTF-8?q?=E6=88=90=E5=90=8E=E5=86=8D=E8=BF=94=E5=9B=9E=20=E4=B8=A4?= =?UTF-8?q?=E4=B8=AA=201D=20UNet=20=E8=BE=93=E5=85=A5=E5=AE=8C=E5=85=A8?= =?UTF-8?q?=E7=8B=AC=E7=AB=8B=EF=BC=8C=E5=85=B1=E4=BA=AB=E7=9A=84=20hs=5Fa?= =?UTF-8?q?=20=E5=92=8C=20context=5Faction=20=E9=83=BD=E6=98=AF=E5=8F=AA?= =?UTF-8?q?=E8=AF=BB=E7=9A=84=E3=80=82GPU=20=E5=88=A9=E7=94=A8=E7=8E=87?= =?UTF-8?q?=E5=8F=AA=E6=9C=89=20~31%=EF=BC=8C=E5=B0=8F=E5=BC=A0=E9=87=8F?= =?UTF-8?q?=20kernel=20=E4=B8=8D=E4=BC=9A=E6=89=93=E6=BB=A1=20GPU=EF=BC=8C?= =?UTF-8?q?=E4=B8=A4=E4=B8=AA=20stream=20=E5=8F=AF=E4=BB=A5=E7=9C=9F?= =?UTF-8?q?=E6=AD=A3=E5=B9=B6=E8=A1=8C=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .claude/settings.local.json | 6 ++++- .envrc | 2 ++ src/unifolm_wma/modules/networks/wma_model.py | 17 ++++++++------ .../case1/output.log | 22 +++++++++---------- 4 files changed, 28 insertions(+), 19 deletions(-) create mode 100644 .envrc diff --git a/.claude/settings.local.json b/.claude/settings.local.json index 3bfcae1..6a54c35 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -5,7 +5,11 @@ "Bash(mamba env:*)", "Bash(micromamba env list:*)", "Bash(echo:*)", - "Bash(git show:*)" + "Bash(git show:*)", + "Bash(nvidia-smi:*)", + "Bash(conda activate unifolm-wma)", + "Bash(conda info:*)", + "Bash(direnv allow:*)" ] } } diff --git a/.envrc b/.envrc new file mode 100644 index 0000000..850aead --- /dev/null +++ b/.envrc @@ -0,0 +1,2 @@ +eval "$(conda shell.bash hook 2>/dev/null)" +conda activate unifolm-wma diff --git a/src/unifolm_wma/modules/networks/wma_model.py b/src/unifolm_wma/modules/networks/wma_model.py index 66a08d2..3244110 100644 --- a/src/unifolm_wma/modules/networks/wma_model.py +++ b/src/unifolm_wma/modules/networks/wma_model.py @@ -688,6 +688,8 @@ class WMAModel(nn.Module): # Context precomputation cache self._ctx_cache_enabled = False self._ctx_cache = {} + # Reusable CUDA stream for parallel state_unet / action_unet + self._state_stream = torch.cuda.Stream() def forward(self, x: Tensor, @@ -842,15 +844,16 @@ class WMAModel(nn.Module): if not self.base_model_gen_only: ba, _, _ = x_action.shape + ts_state = timesteps[:ba] if b > 1 else timesteps + # Run action_unet and state_unet in parallel via CUDA streams + s_stream = self._state_stream + s_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s_stream): + s_y = self.state_unet(x_state, ts_state, hs_a, + context_action[:2], **kwargs) a_y = self.action_unet(x_action, timesteps[:ba], hs_a, context_action[:2], **kwargs) - # Predict state - if b > 1: - s_y = self.state_unet(x_state, timesteps[:ba], hs_a, - context_action[:2], **kwargs) - else: - s_y = self.state_unet(x_state, timesteps, hs_a, - context_action[:2], **kwargs) + torch.cuda.current_stream().wait_stream(s_stream) else: a_y = torch.zeros_like(x_action) s_y = torch.zeros_like(x_state) diff --git a/unitree_z1_dual_arm_stackbox_v2/case1/output.log b/unitree_z1_dual_arm_stackbox_v2/case1/output.log index 47c738b..7565a8c 100644 --- a/unitree_z1_dual_arm_stackbox_v2/case1/output.log +++ b/unitree_z1_dual_arm_stackbox_v2/case1/output.log @@ -1,10 +1,10 @@ -2026-02-10 19:43:34.679819: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. -2026-02-10 19:43:34.729245: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered -2026-02-10 19:43:34.729298: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered -2026-02-10 19:43:34.730600: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered -2026-02-10 19:43:34.738078: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. +2026-02-10 21:29:54.531726: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. +2026-02-10 21:29:54.581091: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered +2026-02-10 21:29:54.581133: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered +2026-02-10 21:29:54.582445: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered +2026-02-10 21:29:54.589984: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. -2026-02-10 19:43:35.659490: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT +2026-02-10 21:29:55.504855: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT Global seed set to 123 >>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ... >>> Prepared model loaded. @@ -30,7 +30,7 @@ DEBUG:h5py._conv:Creating converter from 3 to 5 DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13 DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9 DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096 - 0%| | 0/11 [00:00>> Step 0: generating actions ... + 0%| | 0/11 [00:00>> Step 0: generating actions ... >>> Step 0: interacting with world model ... >>>>>>>>>>>>>>>>>>>>>>>> >>> Step 1: generating actions ... @@ -83,7 +83,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin DEBUG:PIL.Image:Importing XbmImagePlugin DEBUG:PIL.Image:Importing XpmImagePlugin DEBUG:PIL.Image:Importing XVThumbImagePlugin - 18%|█▊ | 2/11 [01:08<05:06, 34.03s/it] 27%|██▋ | 3/11 [01:42<04:34, 34.28s/it] 36%|███▋ | 4/11 [02:17<04:01, 34.45s/it] 45%|████▌ | 5/11 [02:51<03:26, 34.48s/it] 55%|█████▍ | 6/11 [03:26<02:52, 34.50s/it] 64%|██████▎ | 7/11 [04:00<02:18, 34.51s/it] 73%|███████▎ | 8/11 [04:35<01:43, 34.53s/it] 82%|████████▏ | 9/11 [05:10<01:09, 34.56s/it] 91%|█████████ | 10/11 [05:44<00:34, 34.53s/it] 100%|██████████| 11/11 [06:18<00:00, 34.50s/it] 100%|██████████| 11/11 [06:18<00:00, 34.45s/it] + 18%|█▊ | 2/11 [01:07<05:04, 33.81s/it] 27%|██▋ | 3/11 [01:41<04:32, 34.02s/it] 36%|███▋ | 4/11 [02:16<03:58, 34.13s/it] 45%|████▌ | 5/11 [02:50<03:24, 34.12s/it] 55%|█████▍ | 6/11 [03:24<02:50, 34.11s/it] 64%|██████▎ | 7/11 [03:58<02:16, 34.10s/it] 73%|███████▎ | 8/11 [04:32<01:42, 34.11s/it] 82%|████████▏ | 9/11 [05:06<01:08, 34.13s/it] 91%|█████████ | 10/11 [05:40<00:34, 34.15s/it] 100%|██████████| 11/11 [06:14<00:00, 34.12s/it] 100%|██████████| 11/11 [06:14<00:00, 34.09s/it] >>> Step 1: interacting with world model ... >>>>>>>>>>>>>>>>>>>>>>>> >>> Step 2: generating actions ... @@ -114,6 +114,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin >>> Step 10: interacting with world model ... >>>>>>>>>>>>>>>>>>>>>>>> -real 6m56.631s -user 5m36.951s -sys 2m10.073s +real 6m50.156s +user 6m25.849s +sys 1m14.933s