diff --git a/src/unifolm_wma/models/diffusion_head/conditional_unet1d.py b/src/unifolm_wma/models/diffusion_head/conditional_unet1d.py index 12378a1..f63510f 100644 --- a/src/unifolm_wma/models/diffusion_head/conditional_unet1d.py +++ b/src/unifolm_wma/models/diffusion_head/conditional_unet1d.py @@ -501,6 +501,10 @@ class ConditionalUnet1D(nn.Module): self.last_frame_only = last_frame_only self.horizon = horizon + # Context precomputation cache + self._global_cond_cache_enabled = False + self._global_cond_cache = {} + def forward(self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], @@ -530,14 +534,20 @@ class ConditionalUnet1D(nn.Module): B, T, D = sample.shape if self.use_linear_act_proj: sample = self.proj_in_action(sample.unsqueeze(-1)) - global_cond = self.obs_encoder(cond) - global_cond = rearrange(global_cond, - '(b t) d -> b 1 (t d)', - b=B, - t=self.n_obs_steps) - global_cond = repeat(global_cond, - 'b c d -> b (repeat c) d', - repeat=T) + _gc_key = (cond['image'].data_ptr(), cond['agent_pos'].data_ptr()) + if self._global_cond_cache_enabled and _gc_key in self._global_cond_cache: + global_cond = self._global_cond_cache[_gc_key] + else: + global_cond = self.obs_encoder(cond) + global_cond = rearrange(global_cond, + '(b t) d -> b 1 (t d)', + b=B, + t=self.n_obs_steps) + global_cond = repeat(global_cond, + 'b c d -> b (repeat c) d', + repeat=T) + if self._global_cond_cache_enabled: + self._global_cond_cache[_gc_key] = global_cond else: sample = einops.rearrange(sample, 'b h t -> b t h') sample = self.proj_in_horizon(sample) diff --git a/src/unifolm_wma/models/samplers/ddim.py b/src/unifolm_wma/models/samplers/ddim.py index 7323531..e560545 100644 --- a/src/unifolm_wma/models/samplers/ddim.py +++ b/src/unifolm_wma/models/samplers/ddim.py @@ -7,6 +7,7 @@ from unifolm_wma.utils.common import noise_like from unifolm_wma.utils.common import extract_into_tensor from tqdm import tqdm from unifolm_wma.modules.attention import enable_cross_attn_kv_cache, disable_cross_attn_kv_cache +from unifolm_wma.modules.networks.wma_model import enable_ctx_cache, disable_ctx_cache class DDIMSampler(object): @@ -245,6 +246,7 @@ class DDIMSampler(object): dp_ddim_scheduler_state.set_timesteps(len(timesteps)) ts = torch.empty((b, ), device=device, dtype=torch.long) enable_cross_attn_kv_cache(self.model) + enable_ctx_cache(self.model) try: for i, step in enumerate(iterator): index = total_steps - i - 1 @@ -305,6 +307,7 @@ class DDIMSampler(object): intermediates['x_inter_state'].append(state) finally: disable_cross_attn_kv_cache(self.model) + disable_ctx_cache(self.model) return img, action, state, intermediates diff --git a/src/unifolm_wma/modules/networks/wma_model.py b/src/unifolm_wma/modules/networks/wma_model.py index e1b4838..66a08d2 100644 --- a/src/unifolm_wma/modules/networks/wma_model.py +++ b/src/unifolm_wma/modules/networks/wma_model.py @@ -685,6 +685,10 @@ class WMAModel(nn.Module): self.action_token_projector = instantiate_from_config( stem_process_config) + # Context precomputation cache + self._ctx_cache_enabled = False + self._ctx_cache = {} + def forward(self, x: Tensor, x_action: Tensor, @@ -720,58 +724,64 @@ class WMAModel(nn.Module): repeat_only=False).type(x.dtype) emb = self.time_embed(t_emb) - bt, l_context, _ = context.shape - if self.base_model_gen_only: - assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE + _ctx_key = context.data_ptr() + if self._ctx_cache_enabled and _ctx_key in self._ctx_cache: + context = self._ctx_cache[_ctx_key] else: - if l_context == self.n_obs_steps + 77 + t * 16: - context_agent_state = context[:, :self.n_obs_steps] - context_text = context[:, self.n_obs_steps:self.n_obs_steps + - 77, :] - context_img = context[:, self.n_obs_steps + 77:, :] - context_agent_state = context_agent_state.repeat_interleave( - repeats=t, dim=0) - context_text = context_text.repeat_interleave(repeats=t, dim=0) - context_img = rearrange(context_img, - 'b (t l) c -> (b t) l c', - t=t) - context = torch.cat( - [context_agent_state, context_text, context_img], dim=1) - elif l_context == self.n_obs_steps + 16 + 77 + t * 16: - context_agent_state = context[:, :self.n_obs_steps] - context_agent_action = context[:, self. - n_obs_steps:self.n_obs_steps + - 16, :] - context_agent_action = rearrange( - context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d') - context_agent_action = self.action_token_projector( - context_agent_action) - context_agent_action = rearrange(context_agent_action, - '(b o) l d -> b o l d', - o=t) - context_agent_action = rearrange(context_agent_action, - 'b o (t l) d -> b o t l d', - t=t) - context_agent_action = context_agent_action.permute( - 0, 2, 1, 3, 4) - context_agent_action = rearrange(context_agent_action, - 'b t o l d -> (b t) (o l) d') + bt, l_context, _ = context.shape + if self.base_model_gen_only: + assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE + else: + if l_context == self.n_obs_steps + 77 + t * 16: + context_agent_state = context[:, :self.n_obs_steps] + context_text = context[:, self.n_obs_steps:self.n_obs_steps + + 77, :] + context_img = context[:, self.n_obs_steps + 77:, :] + context_agent_state = context_agent_state.repeat_interleave( + repeats=t, dim=0) + context_text = context_text.repeat_interleave(repeats=t, dim=0) + context_img = rearrange(context_img, + 'b (t l) c -> (b t) l c', + t=t) + context = torch.cat( + [context_agent_state, context_text, context_img], dim=1) + elif l_context == self.n_obs_steps + 16 + 77 + t * 16: + context_agent_state = context[:, :self.n_obs_steps] + context_agent_action = context[:, self. + n_obs_steps:self.n_obs_steps + + 16, :] + context_agent_action = rearrange( + context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d') + context_agent_action = self.action_token_projector( + context_agent_action) + context_agent_action = rearrange(context_agent_action, + '(b o) l d -> b o l d', + o=t) + context_agent_action = rearrange(context_agent_action, + 'b o (t l) d -> b o t l d', + t=t) + context_agent_action = context_agent_action.permute( + 0, 2, 1, 3, 4) + context_agent_action = rearrange(context_agent_action, + 'b t o l d -> (b t) (o l) d') - context_text = context[:, self.n_obs_steps + - 16:self.n_obs_steps + 16 + 77, :] - context_text = context_text.repeat_interleave(repeats=t, dim=0) + context_text = context[:, self.n_obs_steps + + 16:self.n_obs_steps + 16 + 77, :] + context_text = context_text.repeat_interleave(repeats=t, dim=0) - context_img = context[:, self.n_obs_steps + 16 + 77:, :] - context_img = rearrange(context_img, - 'b (t l) c -> (b t) l c', - t=t) - context_agent_state = context_agent_state.repeat_interleave( - repeats=t, dim=0) - context = torch.cat([ - context_agent_state, context_agent_action, context_text, - context_img - ], - dim=1) + context_img = context[:, self.n_obs_steps + 16 + 77:, :] + context_img = rearrange(context_img, + 'b (t l) c -> (b t) l c', + t=t) + context_agent_state = context_agent_state.repeat_interleave( + repeats=t, dim=0) + context = torch.cat([ + context_agent_state, context_agent_action, context_text, + context_img + ], + dim=1) + if self._ctx_cache_enabled: + self._ctx_cache[_ctx_key] = context emb = emb.repeat_interleave(repeats=t, dim=0) @@ -846,3 +856,30 @@ class WMAModel(nn.Module): s_y = torch.zeros_like(x_state) return y, a_y, s_y + + +def enable_ctx_cache(model): + """Enable context precomputation cache on WMAModel and its action/state UNets.""" + for m in model.modules(): + if isinstance(m, WMAModel): + m._ctx_cache_enabled = True + m._ctx_cache = {} + # conditional_unet1d cache + from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D + for m in model.modules(): + if isinstance(m, ConditionalUnet1D): + m._global_cond_cache_enabled = True + m._global_cond_cache = {} + + +def disable_ctx_cache(model): + """Disable and clear context precomputation cache.""" + for m in model.modules(): + if isinstance(m, WMAModel): + m._ctx_cache_enabled = False + m._ctx_cache = {} + from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D + for m in model.modules(): + if isinstance(m, ConditionalUnet1D): + m._global_cond_cache_enabled = False + m._global_cond_cache = {} diff --git a/unitree_z1_dual_arm_stackbox_v2/case1/output.log b/unitree_z1_dual_arm_stackbox_v2/case1/output.log index 72b254d..6daeda5 100644 --- a/unitree_z1_dual_arm_stackbox_v2/case1/output.log +++ b/unitree_z1_dual_arm_stackbox_v2/case1/output.log @@ -1,10 +1,10 @@ -2026-02-10 17:25:35.484333: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. -2026-02-10 17:25:35.533963: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered -2026-02-10 17:25:35.534009: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered -2026-02-10 17:25:35.535311: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered -2026-02-10 17:25:35.542814: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. +2026-02-10 17:39:22.590654: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. +2026-02-10 17:39:22.640645: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered +2026-02-10 17:39:22.640689: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered +2026-02-10 17:39:22.642010: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered +2026-02-10 17:39:22.649530: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. -2026-02-10 17:25:36.471650: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT +2026-02-10 17:39:23.575804: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT Global seed set to 123 INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08 @@ -92,7 +92,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin DEBUG:PIL.Image:Importing XbmImagePlugin DEBUG:PIL.Image:Importing XpmImagePlugin DEBUG:PIL.Image:Importing XVThumbImagePlugin - 9%|▉ | 1/11 [00:36<06:07, 36.77s/it] 18%|█▊ | 2/11 [01:13<05:32, 36.92s/it] 27%|██▋ | 3/11 [01:51<04:57, 37.20s/it] 36%|███▋ | 4/11 [02:28<04:21, 37.35s/it] 45%|████▌ | 5/11 [03:06<03:44, 37.39s/it] 55%|█████▍ | 6/11 [03:43<03:06, 37.31s/it] 64%|██████▎ | 7/11 [04:20<02:29, 37.26s/it] 73%|███████▎ | 8/11 [04:57<01:51, 37.24s/it] 82%|████████▏ | 9/11 [05:35<01:14, 37.22s/it] 91%|█████████ | 10/11 [06:12<00:37, 37.23s/it] 100%|██████████| 11/11 [06:49<00:00, 37.23s/it] 100%|██████████| 11/11 [06:49<00:00, 37.23s/it] + 9%|▉ | 1/11 [00:35<05:55, 35.52s/it] 18%|█▊ | 2/11 [01:11<05:21, 35.73s/it] 27%|██▋ | 3/11 [01:47<04:48, 36.04s/it] 36%|███▋ | 4/11 [02:24<04:13, 36.19s/it] 45%|████▌ | 5/11 [03:00<03:37, 36.25s/it] 55%|█████▍ | 6/11 [03:36<03:00, 36.16s/it] 64%|██████▎ | 7/11 [04:12<02:24, 36.09s/it] 73%|███████▎ | 8/11 [04:48<01:48, 36.08s/it] 82%|████████▏ | 9/11 [05:24<01:12, 36.06s/it] 91%|█████████ | 10/11 [06:00<00:36, 36.07s/it] 100%|██████████| 11/11 [06:36<00:00, 36.07s/it] 100%|██████████| 11/11 [06:36<00:00, 36.07s/it] >>>>>>>>>>>>>>>>>>>>>>>> >>> Step 1: generating actions ... >>> Step 1: interacting with world model ... @@ -125,6 +125,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin >>> Step 10: interacting with world model ... >>>>>>>>>>>>>>>>>>>>>>>> -real 8m29.264s -user 9m16.382s -sys 1m14.959s +real 8m13.634s +user 7m37.875s +sys 2m31.672s