diff --git a/scripts/evaluation/world_model_interaction.py b/scripts/evaluation/world_model_interaction.py index 7341151..4b02d3e 100644 --- a/scripts/evaluation/world_model_interaction.py +++ b/scripts/evaluation/world_model_interaction.py @@ -579,6 +579,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: device = get_device_from_parameters(model) + # Fuse KV projections in attention layers (to_k + to_v → to_kv) + from unifolm_wma.modules.attention import CrossAttention + kv_count = sum(1 for m in model.modules() + if isinstance(m, CrossAttention) and m.fuse_kv()) + print(f" ✓ KV fused: {kv_count} attention layers") + # Run over data assert (args.height % 16 == 0) and ( args.width % 16 diff --git a/src/unifolm_wma/modules/attention.py b/src/unifolm_wma/modules/attention.py index cce300c..7443174 100644 --- a/src/unifolm_wma/modules/attention.py +++ b/src/unifolm_wma/modules/attention.py @@ -100,6 +100,7 @@ class CrossAttention(nn.Module): self.agent_action_context_len = agent_action_context_len self._kv_cache = {} self._kv_cache_enabled = False + self._kv_fused = False self.cross_attention_scale_learnable = cross_attention_scale_learnable if self.image_cross_attention: @@ -117,6 +118,27 @@ class CrossAttention(nn.Module): self.register_parameter('alpha_caa', nn.Parameter(torch.tensor(0.))) + def fuse_kv(self): + """Fuse to_k/to_v into to_kv (2 Linear → 1). Works for all layers.""" + k_w = self.to_k.weight # (inner_dim, context_dim) + v_w = self.to_v.weight + self.to_kv = nn.Linear(k_w.shape[1], k_w.shape[0] * 2, bias=False) + self.to_kv.weight = nn.Parameter(torch.cat([k_w, v_w], dim=0)) + del self.to_k, self.to_v + if self.image_cross_attention: + for suffix in ('_ip', '_as', '_aa'): + k_attr = f'to_k{suffix}' + v_attr = f'to_v{suffix}' + kw = getattr(self, k_attr).weight + vw = getattr(self, v_attr).weight + fused = nn.Linear(kw.shape[1], kw.shape[0] * 2, bias=False) + fused.weight = nn.Parameter(torch.cat([kw, vw], dim=0)) + setattr(self, f'to_kv{suffix}', fused) + delattr(self, k_attr) + delattr(self, v_attr) + self._kv_fused = True + return True + def forward(self, x, context=None, mask=None): spatial_self_attn = (context is None) k_ip, v_ip, out_ip = None, None, None @@ -143,19 +165,28 @@ class CrossAttention(nn.Module): self.agent_action_context_len + self.text_context_len:, :] - k = self.to_k(context_ins) - v = self.to_v(context_ins) - k_ip = self.to_k_ip(context_image) - v_ip = self.to_v_ip(context_image) - k_as = self.to_k_as(context_agent_state) - v_as = self.to_v_as(context_agent_state) - k_aa = self.to_k_aa(context_agent_action) - v_aa = self.to_v_aa(context_agent_action) + if self._kv_fused: + k, v = self.to_kv(context_ins).chunk(2, dim=-1) + k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1) + k_as, v_as = self.to_kv_as(context_agent_state).chunk(2, dim=-1) + k_aa, v_aa = self.to_kv_aa(context_agent_action).chunk(2, dim=-1) + else: + k = self.to_k(context_ins) + v = self.to_v(context_ins) + k_ip = self.to_k_ip(context_image) + v_ip = self.to_v_ip(context_image) + k_as = self.to_k_as(context_agent_state) + v_as = self.to_v_as(context_agent_state) + k_aa = self.to_k_aa(context_agent_action) + v_aa = self.to_v_aa(context_agent_action) else: if not spatial_self_attn: context = context[:, :self.text_context_len, :] - k = self.to_k(context) - v = self.to_v(context) + if self._kv_fused: + k, v = self.to_kv(context).chunk(2, dim=-1) + else: + k = self.to_k(context) + v = self.to_v(context) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) @@ -267,10 +298,14 @@ class CrossAttention(nn.Module): elif self.image_cross_attention and not spatial_self_attn: if context.shape[1] == self.text_context_len + self.video_length: context_ins, context_image = context[:, :self.text_context_len, :], context[:,self.text_context_len:, :] - k = self.to_k(context) - v = self.to_v(context) - k_ip = self.to_k_ip(context_image) - v_ip = self.to_v_ip(context_image) + if self._kv_fused: + k, v = self.to_kv(context).chunk(2, dim=-1) + k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1) + else: + k = self.to_k(context) + v = self.to_v(context) + k_ip = self.to_k_ip(context_image) + v_ip = self.to_v_ip(context_image) k, v = map(_reshape_kv, (k, v)) k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip)) if use_cache: @@ -279,12 +314,17 @@ class CrossAttention(nn.Module): context_agent_state = context[:, :self.agent_state_context_len, :] context_ins = context[:, self.agent_state_context_len:self.agent_state_context_len+self.text_context_len, :] context_image = context[:, self.agent_state_context_len+self.text_context_len:, :] - k = self.to_k(context_ins) - v = self.to_v(context_ins) - k_ip = self.to_k_ip(context_image) - v_ip = self.to_v_ip(context_image) - k_as = self.to_k_as(context_agent_state) - v_as = self.to_v_as(context_agent_state) + if self._kv_fused: + k, v = self.to_kv(context_ins).chunk(2, dim=-1) + k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1) + k_as, v_as = self.to_kv_as(context_agent_state).chunk(2, dim=-1) + else: + k = self.to_k(context_ins) + v = self.to_v(context_ins) + k_ip = self.to_k_ip(context_image) + v_ip = self.to_v_ip(context_image) + k_as = self.to_k_as(context_agent_state) + v_as = self.to_v_as(context_agent_state) k, v = map(_reshape_kv, (k, v)) k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip)) k_as, v_as = map(_reshape_kv, (k_as, v_as)) @@ -296,14 +336,20 @@ class CrossAttention(nn.Module): context_ins = context[:, self.agent_state_context_len+self.agent_action_context_len:self.agent_state_context_len+self.agent_action_context_len+self.text_context_len, :] context_image = context[:, self.agent_state_context_len+self.agent_action_context_len+self.text_context_len:, :] - k = self.to_k(context_ins) - v = self.to_v(context_ins) - k_ip = self.to_k_ip(context_image) - v_ip = self.to_v_ip(context_image) - k_as = self.to_k_as(context_agent_state) - v_as = self.to_v_as(context_agent_state) - k_aa = self.to_k_aa(context_agent_action) - v_aa = self.to_v_aa(context_agent_action) + if self._kv_fused: + k, v = self.to_kv(context_ins).chunk(2, dim=-1) + k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1) + k_as, v_as = self.to_kv_as(context_agent_state).chunk(2, dim=-1) + k_aa, v_aa = self.to_kv_aa(context_agent_action).chunk(2, dim=-1) + else: + k = self.to_k(context_ins) + v = self.to_v(context_ins) + k_ip = self.to_k_ip(context_image) + v_ip = self.to_v_ip(context_image) + k_as = self.to_k_as(context_agent_state) + v_as = self.to_v_as(context_agent_state) + k_aa = self.to_k_aa(context_agent_action) + v_aa = self.to_v_aa(context_agent_action) k, v = map(_reshape_kv, (k, v)) k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip)) @@ -328,8 +374,11 @@ class CrossAttention(nn.Module): if not spatial_self_attn: assert 1 > 2, ">>> ERROR: you should never go into here ..." context = context[:, :self.text_context_len, :] - k = self.to_k(context) - v = self.to_v(context) + if self._kv_fused: + k, v = self.to_kv(context).chunk(2, dim=-1) + else: + k = self.to_k(context) + v = self.to_v(context) k, v = map(_reshape_kv, (k, v)) if use_cache: self._kv_cache = {'k': k, 'v': v} diff --git a/unitree_z1_dual_arm_stackbox_v2/case1/output.log b/unitree_z1_dual_arm_stackbox_v2/case1/output.log index 8f36532..6395a97 100644 --- a/unitree_z1_dual_arm_stackbox_v2/case1/output.log +++ b/unitree_z1_dual_arm_stackbox_v2/case1/output.log @@ -1,10 +1,10 @@ -2026-02-10 22:35:08.834827: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. -2026-02-10 22:35:08.884699: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered -2026-02-10 22:35:08.884743: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered -2026-02-10 22:35:08.886076: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered -2026-02-10 22:35:08.893623: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. +2026-02-11 11:59:27.241485: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. +2026-02-11 11:59:27.291755: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered +2026-02-11 11:59:27.291807: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered +2026-02-11 11:59:27.293169: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered +2026-02-11 11:59:27.300838: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. -2026-02-10 22:35:09.824417: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT +2026-02-11 11:59:28.228009: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT Global seed set to 123 >>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ... >>> Prepared model loaded. @@ -25,9 +25,96 @@ INFO:root:***** Configing Data ***** >>> unitree_g1_pack_camera: data stats loaded. >>> unitree_g1_pack_camera: normalizer initiated. >>> Dataset is successfully loaded ... + ✓ KV fused: 66 attention layers >>> Generate 16 frames under each generation ... DEBUG:h5py._conv:Creating converter from 3 to 5 DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13 DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9 DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096 - 0%| | 0/11 [00:00>> Step 0: generating actions ... +>>> Step 0: interacting with world model ... +>>>>>>>>>>>>>>>>>>>>>>>> +>>> Step 1: generating actions ... +DEBUG:PIL.Image:Importing BlpImagePlugin +DEBUG:PIL.Image:Importing BmpImagePlugin +DEBUG:PIL.Image:Importing BufrStubImagePlugin +DEBUG:PIL.Image:Importing CurImagePlugin +DEBUG:PIL.Image:Importing DcxImagePlugin +DEBUG:PIL.Image:Importing DdsImagePlugin +DEBUG:PIL.Image:Importing EpsImagePlugin +DEBUG:PIL.Image:Importing FitsImagePlugin +DEBUG:PIL.Image:Importing FitsStubImagePlugin +DEBUG:PIL.Image:Importing FliImagePlugin +DEBUG:PIL.Image:Importing FpxImagePlugin +DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile' +DEBUG:PIL.Image:Importing FtexImagePlugin +DEBUG:PIL.Image:Importing GbrImagePlugin +DEBUG:PIL.Image:Importing GifImagePlugin +DEBUG:PIL.Image:Importing GribStubImagePlugin +DEBUG:PIL.Image:Importing Hdf5StubImagePlugin +DEBUG:PIL.Image:Importing IcnsImagePlugin +DEBUG:PIL.Image:Importing IcoImagePlugin +DEBUG:PIL.Image:Importing ImImagePlugin +DEBUG:PIL.Image:Importing ImtImagePlugin +DEBUG:PIL.Image:Importing IptcImagePlugin +DEBUG:PIL.Image:Importing JpegImagePlugin +DEBUG:PIL.Image:Importing Jpeg2KImagePlugin +DEBUG:PIL.Image:Importing McIdasImagePlugin +DEBUG:PIL.Image:Importing MicImagePlugin +DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile' +DEBUG:PIL.Image:Importing MpegImagePlugin +DEBUG:PIL.Image:Importing MpoImagePlugin +DEBUG:PIL.Image:Importing MspImagePlugin +DEBUG:PIL.Image:Importing PalmImagePlugin +DEBUG:PIL.Image:Importing PcdImagePlugin +DEBUG:PIL.Image:Importing PcxImagePlugin +DEBUG:PIL.Image:Importing PdfImagePlugin +DEBUG:PIL.Image:Importing PixarImagePlugin +DEBUG:PIL.Image:Importing PngImagePlugin +DEBUG:PIL.Image:Importing PpmImagePlugin +DEBUG:PIL.Image:Importing PsdImagePlugin +DEBUG:PIL.Image:Importing QoiImagePlugin +DEBUG:PIL.Image:Importing SgiImagePlugin +DEBUG:PIL.Image:Importing SpiderImagePlugin +DEBUG:PIL.Image:Importing SunImagePlugin +DEBUG:PIL.Image:Importing TgaImagePlugin +DEBUG:PIL.Image:Importing TiffImagePlugin +DEBUG:PIL.Image:Importing WebPImagePlugin +DEBUG:PIL.Image:Importing WmfImagePlugin +DEBUG:PIL.Image:Importing XbmImagePlugin +DEBUG:PIL.Image:Importing XpmImagePlugin +DEBUG:PIL.Image:Importing XVThumbImagePlugin + 18%|█▊ | 2/11 [01:08<05:07, 34.17s/it] 27%|██▋ | 3/11 [01:42<04:33, 34.16s/it] 36%|███▋ | 4/11 [02:16<03:59, 34.18s/it] 45%|████▌ | 5/11 [02:50<03:24, 34.14s/it] 55%|█████▍ | 6/11 [03:24<02:50, 34.10s/it] 64%|██████▎ | 7/11 [03:58<02:16, 34.07s/it] 73%|███████▎ | 8/11 [04:32<01:42, 34.03s/it] 82%|████████▏ | 9/11 [05:06<01:08, 34.02s/it] 91%|█████████ | 10/11 [05:40<00:34, 34.04s/it] 100%|██████████| 11/11 [06:14<00:00, 34.03s/it] 100%|██████████| 11/11 [06:14<00:00, 34.07s/it] +>>> Step 1: interacting with world model ... +>>>>>>>>>>>>>>>>>>>>>>>> +>>> Step 2: generating actions ... +>>> Step 2: interacting with world model ... +>>>>>>>>>>>>>>>>>>>>>>>> +>>> Step 3: generating actions ... +>>> Step 3: interacting with world model ... +>>>>>>>>>>>>>>>>>>>>>>>> +>>> Step 4: generating actions ... +>>> Step 4: interacting with world model ... +>>>>>>>>>>>>>>>>>>>>>>>> +>>> Step 5: generating actions ... +>>> Step 5: interacting with world model ... +>>>>>>>>>>>>>>>>>>>>>>>> +>>> Step 6: generating actions ... +>>> Step 6: interacting with world model ... +>>>>>>>>>>>>>>>>>>>>>>>> +>>> Step 7: generating actions ... +>>> Step 7: interacting with world model ... +>>>>>>>>>>>>>>>>>>>>>>>> +>>> Step 8: generating actions ... +>>> Step 8: interacting with world model ... +>>>>>>>>>>>>>>>>>>>>>>>> +>>> Step 9: generating actions ... +>>> Step 9: interacting with world model ... +>>>>>>>>>>>>>>>>>>>>>>>> +>>> Step 10: generating actions ... +>>> Step 10: interacting with world model ... +>>>>>>>>>>>>>>>>>>>>>>>> + +real 6m51.758s +user 6m23.024s +sys 1m19.488s diff --git a/unitree_z1_dual_arm_stackbox_v2/case1/psnr_result.json b/unitree_z1_dual_arm_stackbox_v2/case1/psnr_result.json index db0cd38..dec481b 100644 --- a/unitree_z1_dual_arm_stackbox_v2/case1/psnr_result.json +++ b/unitree_z1_dual_arm_stackbox_v2/case1/psnr_result.json @@ -1,5 +1,5 @@ { "gt_video": "/home/qhy/unifolm-world-model-action/unitree_z1_dual_arm_stackbox_v2/case1/unitree_z1_dual_arm_stackbox_v2_case1.mp4", "pred_video": "/home/qhy/unifolm-world-model-action/unitree_z1_dual_arm_stackbox_v2/case1/output/inference/5_full_fs4.mp4", - "psnr": 27.279678834152335 + "psnr": 28.167025381705358 } \ No newline at end of file