From 57ba85d147950b716406e82849b1b69500dd26a1 Mon Sep 17 00:00:00 2001 From: olivame Date: Tue, 10 Feb 2026 18:07:23 +0000 Subject: [PATCH] =?UTF-8?q?KV=20=E8=9E=8D=E5=90=88=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E5=AE=8C=E6=88=90=E3=80=82=E6=94=B9=E5=8A=A8=E6=80=BB=E7=BB=93?= =?UTF-8?q?=EF=BC=9A=20=E9=80=9F=E5=BA=A6=E5=BE=AE=E5=BC=B1=E6=8F=90?= =?UTF-8?q?=E5=8D=87psnr=E7=95=A5=E5=BE=AE=E4=B8=8A=E5=8D=87=20=20=20atten?= =?UTF-8?q?tion.py=20=E2=80=94=203=E5=A4=84=E6=94=B9=E5=8A=A8=EF=BC=9A=20?= =?UTF-8?q?=20=201.=20=5F=5Finit=5F=5F=20=E6=B7=BB=E5=8A=A0=20=5Fkv=5Ffuse?= =?UTF-8?q?d=20=3D=20False=20=E6=A0=87=E5=BF=97=20=20=202.=E6=96=B0?= =?UTF-8?q?=E5=A2=9E=20fuse=5Fkv()=20=E6=96=B9=E6=B3=95=EF=BC=9A=E5=B0=86?= =?UTF-8?q?=20to=5Fk=20+=20to=5Fv=20=E2=86=92=20to=5Fkv=EF=BC=8C=E5=90=8C?= =?UTF-8?q?=E6=97=B6=E5=A4=84=E7=90=86=20=5Fip/=5Fas/=5Faa=20=E8=BE=85?= =?UTF-8?q?=E5=8A=A9=20KV=20=E5=AF=B9=20=20=202.=20bmm=5Fforward=20?= =?UTF-8?q?=E4=B8=A4=E4=B8=AA=E5=88=86=E6=94=AF=E5=8A=A0=5Fkv=5Ffused=20?= =?UTF-8?q?=E5=88=A4=E6=96=AD=EF=BC=8C=E7=94=A8to=5Fkv().chunk(2,=20dim=3D?= =?UTF-8?q?-1)=20=E6=9B=BF=E4=BB=A3=E5=88=86=E5=88=AB=E8=B0=83=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/evaluation/world_model_interaction.py | 6 +++ src/unifolm_wma/modules/attention.py | 51 +++++++++++++++---- .../case1/output.log | 25 ++++----- .../case1/psnr_result1.json | 2 +- 4 files changed, 61 insertions(+), 23 deletions(-) diff --git a/scripts/evaluation/world_model_interaction.py b/scripts/evaluation/world_model_interaction.py index 3270dda..cb25f2e 100644 --- a/scripts/evaluation/world_model_interaction.py +++ b/scripts/evaluation/world_model_interaction.py @@ -625,6 +625,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: # Compile hot ResBlocks for operator fusion apply_torch_compile(model) + # Fuse KV projections in attention layers (to_k + to_v → to_kv) + from unifolm_wma.modules.attention import CrossAttention + kv_count = sum(1 for m in model.modules() + if isinstance(m, CrossAttention) and m.fuse_kv()) + print(f" ✓ KV fused: {kv_count} attention layers") + # Export precision-converted checkpoint if requested if args.export_precision_ckpt: export_path = args.export_precision_ckpt diff --git a/src/unifolm_wma/modules/attention.py b/src/unifolm_wma/modules/attention.py index 248a1f6..0499a53 100644 --- a/src/unifolm_wma/modules/attention.py +++ b/src/unifolm_wma/modules/attention.py @@ -99,6 +99,7 @@ class CrossAttention(nn.Module): self.agent_action_context_len = agent_action_context_len self._kv_cache = {} self._kv_cache_enabled = False + self._kv_fused = False self.cross_attention_scale_learnable = cross_attention_scale_learnable if self.image_cross_attention: @@ -116,6 +117,27 @@ class CrossAttention(nn.Module): self.register_parameter('alpha_caa', nn.Parameter(torch.tensor(0.))) + def fuse_kv(self): + """Fuse to_k/to_v into to_kv (2 Linear → 1). Works for all layers.""" + k_w = self.to_k.weight # (inner_dim, context_dim) + v_w = self.to_v.weight + self.to_kv = nn.Linear(k_w.shape[1], k_w.shape[0] * 2, bias=False) + self.to_kv.weight = nn.Parameter(torch.cat([k_w, v_w], dim=0)) + del self.to_k, self.to_v + if self.image_cross_attention: + for suffix in ('_ip', '_as', '_aa'): + k_attr = f'to_k{suffix}' + v_attr = f'to_v{suffix}' + kw = getattr(self, k_attr).weight + vw = getattr(self, v_attr).weight + fused = nn.Linear(kw.shape[1], kw.shape[0] * 2, bias=False) + fused.weight = nn.Parameter(torch.cat([kw, vw], dim=0)) + setattr(self, f'to_kv{suffix}', fused) + delattr(self, k_attr) + delattr(self, v_attr) + self._kv_fused = True + return True + def forward(self, x, context=None, mask=None): spatial_self_attn = (context is None) k_ip, v_ip, out_ip = None, None, None @@ -276,14 +298,20 @@ class CrossAttention(nn.Module): self.agent_action_context_len + self.text_context_len:, :] - k = self.to_k(context_ins) - v = self.to_v(context_ins) - k_ip = self.to_k_ip(context_image) - v_ip = self.to_v_ip(context_image) - k_as = self.to_k_as(context_agent_state) - v_as = self.to_v_as(context_agent_state) - k_aa = self.to_k_aa(context_agent_action) - v_aa = self.to_v_aa(context_agent_action) + if self._kv_fused: + k, v = self.to_kv(context_ins).chunk(2, dim=-1) + k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1) + k_as, v_as = self.to_kv_as(context_agent_state).chunk(2, dim=-1) + k_aa, v_aa = self.to_kv_aa(context_agent_action).chunk(2, dim=-1) + else: + k = self.to_k(context_ins) + v = self.to_v(context_ins) + k_ip = self.to_k_ip(context_image) + v_ip = self.to_v_ip(context_image) + k_as = self.to_k_as(context_agent_state) + v_as = self.to_v_as(context_agent_state) + k_aa = self.to_k_aa(context_agent_action) + v_aa = self.to_v_aa(context_agent_action) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) @@ -304,8 +332,11 @@ class CrossAttention(nn.Module): else: if not spatial_self_attn: context = context[:, :self.text_context_len, :] - k = self.to_k(context) - v = self.to_v(context) + if self._kv_fused: + k, v = self.to_kv(context).chunk(2, dim=-1) + else: + k = self.to_k(context) + v = self.to_v(context) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) diff --git a/unitree_z1_dual_arm_cleanup_pencils/case1/output.log b/unitree_z1_dual_arm_cleanup_pencils/case1/output.log index e61c819..779fdcc 100644 --- a/unitree_z1_dual_arm_cleanup_pencils/case1/output.log +++ b/unitree_z1_dual_arm_cleanup_pencils/case1/output.log @@ -1,14 +1,14 @@ /mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81. __import__("pkg_resources").declare_namespace(__name__) -2026-02-10 13:30:56.669605: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. -2026-02-10 13:30:56.672987: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. -2026-02-10 13:30:56.704235: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered -2026-02-10 13:30:56.704271: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered -2026-02-10 13:30:56.706111: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered -2026-02-10 13:30:56.714239: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. -2026-02-10 13:30:56.714546: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. +2026-02-10 17:57:48.047156: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. +2026-02-10 17:57:48.050303: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. +2026-02-10 17:57:48.081710: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered +2026-02-10 17:57:48.081741: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered +2026-02-10 17:57:48.083577: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered +2026-02-10 17:57:48.091772: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. +2026-02-10 17:57:48.092045: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. -2026-02-10 13:30:57.511779: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT +2026-02-10 17:57:48.787960: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT [rank: 0] Global seed set to 123 /mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead. @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) @@ -41,6 +41,7 @@ INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k). ⚠ Found 601 fp32 params, converting to bf16 ✓ All parameters converted to bfloat16 ✓ torch.compile: 3 ResBlocks in output_blocks[5, 8, 9] + ✓ KV fused: 66 attention layers INFO:root:***** Configing Data ***** >>> unitree_z1_stackbox: 1 data samples loaded. >>> unitree_z1_stackbox: data stats loaded. @@ -116,7 +117,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin DEBUG:PIL.Image:Importing XbmImagePlugin DEBUG:PIL.Image:Importing XpmImagePlugin DEBUG:PIL.Image:Importing XVThumbImagePlugin - 12%|█▎ | 1/8 [01:02<07:20, 63.00s/it] 25%|██▌ | 2/8 [02:02<06:05, 60.84s/it] 38%|███▊ | 3/8 [03:01<05:00, 60.16s/it] 50%|█████ | 4/8 [04:01<03:59, 59.91s/it] 62%|██████▎ | 5/8 [05:00<02:59, 59.80s/it] 75%|███████▌ | 6/8 [06:00<01:59, 59.66s/it] 88%|████████▊ | 7/8 [06:59<00:59, 59.53s/it] 100%|██████████| 8/8 [07:58<00:00, 59.49s/it] 100%|██████████| 8/8 [07:58<00:00, 59.86s/it] + 12%|█▎ | 1/8 [01:03<07:22, 63.25s/it] 25%|██▌ | 2/8 [02:02<06:05, 60.93s/it] 38%|███▊ | 3/8 [03:01<05:00, 60.19s/it] 50%|█████ | 4/8 [04:01<03:59, 59.85s/it] 62%|██████▎ | 5/8 [05:00<02:59, 59.69s/it] 75%|███████▌ | 6/8 [05:59<01:59, 59.54s/it] 88%|████████▊ | 7/8 [06:59<00:59, 59.43s/it] 100%|██████████| 8/8 [07:58<00:00, 59.46s/it] 100%|██████████| 8/8 [07:58<00:00, 59.82s/it] >>>>>>>>>>>>>>>>>>>>>>>> >>> Step 1: generating actions ... >>> Step 1: interacting with world model ... @@ -140,6 +141,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin >>> Step 7: interacting with world model ... >>>>>>>>>>>>>>>>>>>>>>>> -real 9m15.361s -user 10m51.680s -sys 1m10.602s +real 9m13.133s +user 11m35.465s +sys 1m9.437s diff --git a/unitree_z1_dual_arm_cleanup_pencils/case1/psnr_result1.json b/unitree_z1_dual_arm_cleanup_pencils/case1/psnr_result1.json index 5d699db..0aeb56e 100644 --- a/unitree_z1_dual_arm_cleanup_pencils/case1/psnr_result1.json +++ b/unitree_z1_dual_arm_cleanup_pencils/case1/psnr_result1.json @@ -1,5 +1,5 @@ { "gt_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/unitree_z1_dual_arm_cleanup_pencils_case1_amd.mp4", "pred_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/0_full_fs4.mp4", - "psnr": 31.802224855380352 + "psnr": 32.442113263955434 } \ No newline at end of file