From 6dca3696d86003df3ab128438156b0b5cc7cba14 Mon Sep 17 00:00:00 2001 From: olivame Date: Mon, 9 Feb 2026 17:42:47 +0000 Subject: [PATCH] =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E4=BA=86Context=20=E9=A2=84?= =?UTF-8?q?=E8=AE=A1=E7=AE=97=E5=92=8C=E7=BC=93=E5=AD=98=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=8C=E6=8F=90=E5=8D=87=E4=BA=86=E9=87=87=E6=A0=B7=E6=95=88?= =?UTF-8?q?=E7=8E=87=E3=80=82=20psnr=E4=B8=8D=E4=B8=8B=E9=99=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../diffusion_head/conditional_unet1d.py | 26 ++-- src/unifolm_wma/models/samplers/ddim.py | 3 + src/unifolm_wma/modules/networks/wma_model.py | 135 +++++++++++------- .../case1/output.log | 24 ++-- 4 files changed, 119 insertions(+), 69 deletions(-) diff --git a/src/unifolm_wma/models/diffusion_head/conditional_unet1d.py b/src/unifolm_wma/models/diffusion_head/conditional_unet1d.py index 12378a1..f63510f 100644 --- a/src/unifolm_wma/models/diffusion_head/conditional_unet1d.py +++ b/src/unifolm_wma/models/diffusion_head/conditional_unet1d.py @@ -501,6 +501,10 @@ class ConditionalUnet1D(nn.Module): self.last_frame_only = last_frame_only self.horizon = horizon + # Context precomputation cache + self._global_cond_cache_enabled = False + self._global_cond_cache = {} + def forward(self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], @@ -530,14 +534,20 @@ class ConditionalUnet1D(nn.Module): B, T, D = sample.shape if self.use_linear_act_proj: sample = self.proj_in_action(sample.unsqueeze(-1)) - global_cond = self.obs_encoder(cond) - global_cond = rearrange(global_cond, - '(b t) d -> b 1 (t d)', - b=B, - t=self.n_obs_steps) - global_cond = repeat(global_cond, - 'b c d -> b (repeat c) d', - repeat=T) + _gc_key = (cond['image'].data_ptr(), cond['agent_pos'].data_ptr()) + if self._global_cond_cache_enabled and _gc_key in self._global_cond_cache: + global_cond = self._global_cond_cache[_gc_key] + else: + global_cond = self.obs_encoder(cond) + global_cond = rearrange(global_cond, + '(b t) d -> b 1 (t d)', + b=B, + t=self.n_obs_steps) + global_cond = repeat(global_cond, + 'b c d -> b (repeat c) d', + repeat=T) + if self._global_cond_cache_enabled: + self._global_cond_cache[_gc_key] = global_cond else: sample = einops.rearrange(sample, 'b h t -> b t h') sample = self.proj_in_horizon(sample) diff --git a/src/unifolm_wma/models/samplers/ddim.py b/src/unifolm_wma/models/samplers/ddim.py index fe50ea0..698732e 100644 --- a/src/unifolm_wma/models/samplers/ddim.py +++ b/src/unifolm_wma/models/samplers/ddim.py @@ -7,6 +7,7 @@ from unifolm_wma.utils.common import noise_like from unifolm_wma.utils.common import extract_into_tensor from tqdm import tqdm from unifolm_wma.modules.attention import enable_cross_attn_kv_cache, disable_cross_attn_kv_cache +from unifolm_wma.modules.networks.wma_model import enable_ctx_cache, disable_ctx_cache class DDIMSampler(object): @@ -245,6 +246,7 @@ class DDIMSampler(object): dp_ddim_scheduler_state.set_timesteps(len(timesteps)) ts = torch.empty((b, ), device=device, dtype=torch.long) enable_cross_attn_kv_cache(self.model) + enable_ctx_cache(self.model) try: for i, step in enumerate(iterator): index = total_steps - i - 1 @@ -305,6 +307,7 @@ class DDIMSampler(object): intermediates['x_inter_state'].append(state) finally: disable_cross_attn_kv_cache(self.model) + disable_ctx_cache(self.model) return img, action, state, intermediates diff --git a/src/unifolm_wma/modules/networks/wma_model.py b/src/unifolm_wma/modules/networks/wma_model.py index 14c5478..cc7f4ba 100644 --- a/src/unifolm_wma/modules/networks/wma_model.py +++ b/src/unifolm_wma/modules/networks/wma_model.py @@ -685,6 +685,10 @@ class WMAModel(nn.Module): self.action_token_projector = instantiate_from_config( stem_process_config) + # Context precomputation cache + self._ctx_cache_enabled = False + self._ctx_cache = {} + def forward(self, x: Tensor, x_action: Tensor, @@ -720,58 +724,64 @@ class WMAModel(nn.Module): repeat_only=False).type(x.dtype) emb = self.time_embed(t_emb) - bt, l_context, _ = context.shape - if self.base_model_gen_only: - assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE + _ctx_key = context.data_ptr() + if self._ctx_cache_enabled and _ctx_key in self._ctx_cache: + context = self._ctx_cache[_ctx_key] else: - if l_context == self.n_obs_steps + 77 + t * 16: - context_agent_state = context[:, :self.n_obs_steps] - context_text = context[:, self.n_obs_steps:self.n_obs_steps + - 77, :] - context_img = context[:, self.n_obs_steps + 77:, :] - context_agent_state = context_agent_state.repeat_interleave( - repeats=t, dim=0) - context_text = context_text.repeat_interleave(repeats=t, dim=0) - context_img = rearrange(context_img, - 'b (t l) c -> (b t) l c', - t=t) - context = torch.cat( - [context_agent_state, context_text, context_img], dim=1) - elif l_context == self.n_obs_steps + 16 + 77 + t * 16: - context_agent_state = context[:, :self.n_obs_steps] - context_agent_action = context[:, self. - n_obs_steps:self.n_obs_steps + - 16, :] - context_agent_action = rearrange( - context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d') - context_agent_action = self.action_token_projector( - context_agent_action) - context_agent_action = rearrange(context_agent_action, - '(b o) l d -> b o l d', - o=t) - context_agent_action = rearrange(context_agent_action, - 'b o (t l) d -> b o t l d', - t=t) - context_agent_action = context_agent_action.permute( - 0, 2, 1, 3, 4) - context_agent_action = rearrange(context_agent_action, - 'b t o l d -> (b t) (o l) d') + bt, l_context, _ = context.shape + if self.base_model_gen_only: + assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE + else: + if l_context == self.n_obs_steps + 77 + t * 16: + context_agent_state = context[:, :self.n_obs_steps] + context_text = context[:, self.n_obs_steps:self.n_obs_steps + + 77, :] + context_img = context[:, self.n_obs_steps + 77:, :] + context_agent_state = context_agent_state.repeat_interleave( + repeats=t, dim=0) + context_text = context_text.repeat_interleave(repeats=t, dim=0) + context_img = rearrange(context_img, + 'b (t l) c -> (b t) l c', + t=t) + context = torch.cat( + [context_agent_state, context_text, context_img], dim=1) + elif l_context == self.n_obs_steps + 16 + 77 + t * 16: + context_agent_state = context[:, :self.n_obs_steps] + context_agent_action = context[:, self. + n_obs_steps:self.n_obs_steps + + 16, :] + context_agent_action = rearrange( + context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d') + context_agent_action = self.action_token_projector( + context_agent_action) + context_agent_action = rearrange(context_agent_action, + '(b o) l d -> b o l d', + o=t) + context_agent_action = rearrange(context_agent_action, + 'b o (t l) d -> b o t l d', + t=t) + context_agent_action = context_agent_action.permute( + 0, 2, 1, 3, 4) + context_agent_action = rearrange(context_agent_action, + 'b t o l d -> (b t) (o l) d') - context_text = context[:, self.n_obs_steps + - 16:self.n_obs_steps + 16 + 77, :] - context_text = context_text.repeat_interleave(repeats=t, dim=0) + context_text = context[:, self.n_obs_steps + + 16:self.n_obs_steps + 16 + 77, :] + context_text = context_text.repeat_interleave(repeats=t, dim=0) - context_img = context[:, self.n_obs_steps + 16 + 77:, :] - context_img = rearrange(context_img, - 'b (t l) c -> (b t) l c', - t=t) - context_agent_state = context_agent_state.repeat_interleave( - repeats=t, dim=0) - context = torch.cat([ - context_agent_state, context_agent_action, context_text, - context_img - ], - dim=1) + context_img = context[:, self.n_obs_steps + 16 + 77:, :] + context_img = rearrange(context_img, + 'b (t l) c -> (b t) l c', + t=t) + context_agent_state = context_agent_state.repeat_interleave( + repeats=t, dim=0) + context = torch.cat([ + context_agent_state, context_agent_action, context_text, + context_img + ], + dim=1) + if self._ctx_cache_enabled: + self._ctx_cache[_ctx_key] = context emb = emb.repeat_interleave(repeats=t, dim=0) @@ -846,3 +856,30 @@ class WMAModel(nn.Module): s_y = torch.zeros_like(x_state) return y, a_y, s_y + + +def enable_ctx_cache(model): + """Enable context precomputation cache on WMAModel and its action/state UNets.""" + for m in model.modules(): + if isinstance(m, WMAModel): + m._ctx_cache_enabled = True + m._ctx_cache = {} + # conditional_unet1d cache + from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D + for m in model.modules(): + if isinstance(m, ConditionalUnet1D): + m._global_cond_cache_enabled = True + m._global_cond_cache = {} + + +def disable_ctx_cache(model): + """Disable and clear context precomputation cache.""" + for m in model.modules(): + if isinstance(m, WMAModel): + m._ctx_cache_enabled = False + m._ctx_cache = {} + from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D + for m in model.modules(): + if isinstance(m, ConditionalUnet1D): + m._global_cond_cache_enabled = False + m._global_cond_cache = {} diff --git a/unitree_z1_dual_arm_cleanup_pencils/case1/output.log b/unitree_z1_dual_arm_cleanup_pencils/case1/output.log index 4286993..69aa6e7 100644 --- a/unitree_z1_dual_arm_cleanup_pencils/case1/output.log +++ b/unitree_z1_dual_arm_cleanup_pencils/case1/output.log @@ -1,14 +1,14 @@ /mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81. __import__("pkg_resources").declare_namespace(__name__) -2026-02-09 16:53:59.556813: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. -2026-02-09 16:53:59.559892: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. -2026-02-09 16:53:59.591414: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered -2026-02-09 16:53:59.591446: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered -2026-02-09 16:53:59.593281: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered -2026-02-09 16:53:59.601486: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. -2026-02-09 16:53:59.601838: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. +2026-02-09 17:32:41.850068: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. +2026-02-09 17:32:41.853132: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. +2026-02-09 17:32:41.886058: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered +2026-02-09 17:32:41.886103: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered +2026-02-09 17:32:41.887979: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered +2026-02-09 17:32:41.896994: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. +2026-02-09 17:32:41.897283: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. -2026-02-09 16:54:00.228108: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT +2026-02-09 17:32:42.611394: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT [rank: 0] Global seed set to 123 /mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead. @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) @@ -116,7 +116,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin DEBUG:PIL.Image:Importing XbmImagePlugin DEBUG:PIL.Image:Importing XpmImagePlugin DEBUG:PIL.Image:Importing XVThumbImagePlugin - 12%|█▎ | 1/8 [01:09<08:08, 69.72s/it] 25%|██▌ | 2/8 [02:15<06:45, 67.61s/it] 38%|███▊ | 3/8 [03:21<05:34, 66.92s/it] 50%|█████ | 4/8 [04:28<04:26, 66.60s/it] 62%|██████▎ | 5/8 [05:34<03:19, 66.44s/it] 75%|███████▌ | 6/8 [06:40<02:12, 66.32s/it] 88%|████████▊ | 7/8 [07:46<01:06, 66.25s/it] 100%|██████████| 8/8 [08:52<00:00, 66.23s/it] 100%|██████████| 8/8 [08:52<00:00, 66.57s/it] + 12%|█▎ | 1/8 [01:08<08:00, 68.62s/it] 25%|██▌ | 2/8 [02:13<06:38, 66.41s/it] 38%|███▊ | 3/8 [03:18<05:29, 65.84s/it] 50%|█████ | 4/8 [04:23<04:22, 65.55s/it] 62%|██████▎ | 5/8 [05:28<03:16, 65.36s/it] 75%|███████▌ | 6/8 [06:33<02:10, 65.25s/it] 88%|████████▊ | 7/8 [07:38<01:05, 65.14s/it] 100%|██████████| 8/8 [08:43<00:00, 65.10s/it] 100%|██████████| 8/8 [08:43<00:00, 65.47s/it] >>>>>>>>>>>>>>>>>>>>>>>> >>> Step 1: generating actions ... >>> Step 1: interacting with world model ... @@ -140,6 +140,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin >>> Step 7: interacting with world model ... >>>>>>>>>>>>>>>>>>>>>>>> -real 9m53.691s -user 11m23.200s -sys 0m42.702s +real 9m47.606s +user 8m5.267s +sys 1m7.101s