KV 融合实现完成。改动总结: 速度微弱提升psnr略微上升
attention.py — 3处改动: 1. __init__ 添加 _kv_fused = False 标志 2.新增 fuse_kv() 方法:将 to_k + to_v → to_kv,同时处理 _ip/_as/_aa 辅助 KV 对 2. bmm_forward 两个分支加_kv_fused 判断,用to_kv().chunk(2, dim=-1) 替代分别调用
This commit is contained in:
@@ -625,6 +625,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
# Compile hot ResBlocks for operator fusion
|
# Compile hot ResBlocks for operator fusion
|
||||||
apply_torch_compile(model)
|
apply_torch_compile(model)
|
||||||
|
|
||||||
|
# Fuse KV projections in attention layers (to_k + to_v → to_kv)
|
||||||
|
from unifolm_wma.modules.attention import CrossAttention
|
||||||
|
kv_count = sum(1 for m in model.modules()
|
||||||
|
if isinstance(m, CrossAttention) and m.fuse_kv())
|
||||||
|
print(f" ✓ KV fused: {kv_count} attention layers")
|
||||||
|
|
||||||
# Export precision-converted checkpoint if requested
|
# Export precision-converted checkpoint if requested
|
||||||
if args.export_precision_ckpt:
|
if args.export_precision_ckpt:
|
||||||
export_path = args.export_precision_ckpt
|
export_path = args.export_precision_ckpt
|
||||||
|
|||||||
@@ -99,6 +99,7 @@ class CrossAttention(nn.Module):
|
|||||||
self.agent_action_context_len = agent_action_context_len
|
self.agent_action_context_len = agent_action_context_len
|
||||||
self._kv_cache = {}
|
self._kv_cache = {}
|
||||||
self._kv_cache_enabled = False
|
self._kv_cache_enabled = False
|
||||||
|
self._kv_fused = False
|
||||||
|
|
||||||
self.cross_attention_scale_learnable = cross_attention_scale_learnable
|
self.cross_attention_scale_learnable = cross_attention_scale_learnable
|
||||||
if self.image_cross_attention:
|
if self.image_cross_attention:
|
||||||
@@ -116,6 +117,27 @@ class CrossAttention(nn.Module):
|
|||||||
self.register_parameter('alpha_caa',
|
self.register_parameter('alpha_caa',
|
||||||
nn.Parameter(torch.tensor(0.)))
|
nn.Parameter(torch.tensor(0.)))
|
||||||
|
|
||||||
|
def fuse_kv(self):
|
||||||
|
"""Fuse to_k/to_v into to_kv (2 Linear → 1). Works for all layers."""
|
||||||
|
k_w = self.to_k.weight # (inner_dim, context_dim)
|
||||||
|
v_w = self.to_v.weight
|
||||||
|
self.to_kv = nn.Linear(k_w.shape[1], k_w.shape[0] * 2, bias=False)
|
||||||
|
self.to_kv.weight = nn.Parameter(torch.cat([k_w, v_w], dim=0))
|
||||||
|
del self.to_k, self.to_v
|
||||||
|
if self.image_cross_attention:
|
||||||
|
for suffix in ('_ip', '_as', '_aa'):
|
||||||
|
k_attr = f'to_k{suffix}'
|
||||||
|
v_attr = f'to_v{suffix}'
|
||||||
|
kw = getattr(self, k_attr).weight
|
||||||
|
vw = getattr(self, v_attr).weight
|
||||||
|
fused = nn.Linear(kw.shape[1], kw.shape[0] * 2, bias=False)
|
||||||
|
fused.weight = nn.Parameter(torch.cat([kw, vw], dim=0))
|
||||||
|
setattr(self, f'to_kv{suffix}', fused)
|
||||||
|
delattr(self, k_attr)
|
||||||
|
delattr(self, v_attr)
|
||||||
|
self._kv_fused = True
|
||||||
|
return True
|
||||||
|
|
||||||
def forward(self, x, context=None, mask=None):
|
def forward(self, x, context=None, mask=None):
|
||||||
spatial_self_attn = (context is None)
|
spatial_self_attn = (context is None)
|
||||||
k_ip, v_ip, out_ip = None, None, None
|
k_ip, v_ip, out_ip = None, None, None
|
||||||
@@ -276,14 +298,20 @@ class CrossAttention(nn.Module):
|
|||||||
self.agent_action_context_len +
|
self.agent_action_context_len +
|
||||||
self.text_context_len:, :]
|
self.text_context_len:, :]
|
||||||
|
|
||||||
k = self.to_k(context_ins)
|
if self._kv_fused:
|
||||||
v = self.to_v(context_ins)
|
k, v = self.to_kv(context_ins).chunk(2, dim=-1)
|
||||||
k_ip = self.to_k_ip(context_image)
|
k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1)
|
||||||
v_ip = self.to_v_ip(context_image)
|
k_as, v_as = self.to_kv_as(context_agent_state).chunk(2, dim=-1)
|
||||||
k_as = self.to_k_as(context_agent_state)
|
k_aa, v_aa = self.to_kv_aa(context_agent_action).chunk(2, dim=-1)
|
||||||
v_as = self.to_v_as(context_agent_state)
|
else:
|
||||||
k_aa = self.to_k_aa(context_agent_action)
|
k = self.to_k(context_ins)
|
||||||
v_aa = self.to_v_aa(context_agent_action)
|
v = self.to_v(context_ins)
|
||||||
|
k_ip = self.to_k_ip(context_image)
|
||||||
|
v_ip = self.to_v_ip(context_image)
|
||||||
|
k_as = self.to_k_as(context_agent_state)
|
||||||
|
v_as = self.to_v_as(context_agent_state)
|
||||||
|
k_aa = self.to_k_aa(context_agent_action)
|
||||||
|
v_aa = self.to_v_aa(context_agent_action)
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||||
(q, k, v))
|
(q, k, v))
|
||||||
@@ -304,8 +332,11 @@ class CrossAttention(nn.Module):
|
|||||||
else:
|
else:
|
||||||
if not spatial_self_attn:
|
if not spatial_self_attn:
|
||||||
context = context[:, :self.text_context_len, :]
|
context = context[:, :self.text_context_len, :]
|
||||||
k = self.to_k(context)
|
if self._kv_fused:
|
||||||
v = self.to_v(context)
|
k, v = self.to_kv(context).chunk(2, dim=-1)
|
||||||
|
else:
|
||||||
|
k = self.to_k(context)
|
||||||
|
v = self.to_v(context)
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||||
(q, k, v))
|
(q, k, v))
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||||
__import__("pkg_resources").declare_namespace(__name__)
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
2026-02-10 13:30:56.669605: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-10 17:57:48.047156: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
2026-02-10 13:30:56.672987: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-10 17:57:48.050303: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
2026-02-10 13:30:56.704235: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-10 17:57:48.081710: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-10 13:30:56.704271: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
2026-02-10 17:57:48.081741: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-10 13:30:56.706111: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-10 17:57:48.083577: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-10 13:30:56.714239: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-10 17:57:48.091772: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
2026-02-10 13:30:56.714546: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
2026-02-10 17:57:48.092045: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-10 13:30:57.511779: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
2026-02-10 17:57:48.787960: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
[rank: 0] Global seed set to 123
|
[rank: 0] Global seed set to 123
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
@@ -41,6 +41,7 @@ INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|||||||
⚠ Found 601 fp32 params, converting to bf16
|
⚠ Found 601 fp32 params, converting to bf16
|
||||||
✓ All parameters converted to bfloat16
|
✓ All parameters converted to bfloat16
|
||||||
✓ torch.compile: 3 ResBlocks in output_blocks[5, 8, 9]
|
✓ torch.compile: 3 ResBlocks in output_blocks[5, 8, 9]
|
||||||
|
✓ KV fused: 66 attention layers
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -116,7 +117,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
|
|||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
12%|█▎ | 1/8 [01:03<07:22, 63.25s/it]
|
12%|█▎ | 1/8 [01:03<07:22, 63.25s/it]
|
||||||
25%|██▌ | 2/8 [02:02<06:05, 60.93s/it]
|
25%|██▌ | 2/8 [02:02<06:05, 60.93s/it]
|
||||||
@@ -140,6 +141,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 4: generating actions ...
|
>>> Step 4: generating actions ...
|
||||||
>>> Step 4: interacting with world model ...
|
>>> Step 4: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 5: generating actions ...
|
>>> Step 5: generating actions ...
|
||||||
>>> Step 5: interacting with world model ...
|
>>> Step 5: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/unitree_z1_dual_arm_cleanup_pencils_case1_amd.mp4",
|
"gt_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/unitree_z1_dual_arm_cleanup_pencils_case1_amd.mp4",
|
||||||
"pred_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/0_full_fs4.mp4",
|
"pred_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/0_full_fs4.mp4",
|
||||||
"psnr": 31.802224855380352
|
"psnr": 32.442113263955434
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user