From 223a50f9e030cce709cb6199a565434dba7df7b5 Mon Sep 17 00:00:00 2001 From: qhy <2728290997@qq.com> Date: Tue, 10 Feb 2026 17:35:03 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0CrossAttention=20kv=E7=BC=93?= =?UTF-8?q?=E5=AD=98=EF=BC=8C=E5=87=8F=E5=B0=91=E9=87=8D=E5=A4=8D=E8=AE=A1?= =?UTF-8?q?=E7=AE=97=EF=BC=8C=E6=8F=90=E5=8D=87=E6=80=A7=E8=83=BD=EF=BC=8C?= =?UTF-8?q?psnr=3D25.1201dB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .claude/settings.local.json | 10 ++ .gitignore | 2 +- src/unifolm_wma/models/samplers/ddim.py | 107 +++++++------- src/unifolm_wma/modules/attention.py | 132 +++++++++++------- .../case1/psnr_result.json | 5 + .../case1/output.log | 20 +-- .../case1/psnr_result.json | 5 + 7 files changed, 166 insertions(+), 115 deletions(-) create mode 100644 .claude/settings.local.json create mode 100644 unitree_z1_dual_arm_cleanup_pencils/case1/psnr_result.json create mode 100644 unitree_z1_dual_arm_stackbox_v2/case1/psnr_result.json diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..79f6368 --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,10 @@ +{ + "permissions": { + "allow": [ + "Bash(conda env list:*)", + "Bash(mamba env:*)", + "Bash(micromamba env list:*)", + "Bash(echo:*)" + ] + } +} diff --git a/.gitignore b/.gitignore index 67edfeb..d6bfa21 100644 --- a/.gitignore +++ b/.gitignore @@ -120,7 +120,7 @@ localTest/ fig/ figure/ *.mp4 -*.json + Data/ControlVAE.yml Data/Misc Data/Pretrained diff --git a/src/unifolm_wma/models/samplers/ddim.py b/src/unifolm_wma/models/samplers/ddim.py index 2e88f0b..7323531 100644 --- a/src/unifolm_wma/models/samplers/ddim.py +++ b/src/unifolm_wma/models/samplers/ddim.py @@ -6,6 +6,7 @@ from unifolm_wma.utils.diffusion import make_ddim_sampling_parameters, make_ddim from unifolm_wma.utils.common import noise_like from unifolm_wma.utils.common import extract_into_tensor from tqdm import tqdm +from unifolm_wma.modules.attention import enable_cross_attn_kv_cache, disable_cross_attn_kv_cache class DDIMSampler(object): @@ -243,63 +244,67 @@ class DDIMSampler(object): dp_ddim_scheduler_action.set_timesteps(len(timesteps)) dp_ddim_scheduler_state.set_timesteps(len(timesteps)) ts = torch.empty((b, ), device=device, dtype=torch.long) - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts.fill_(step) + enable_cross_attn_kv_cache(self.model) + try: + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts.fill_(step) - # Use mask to blend noised original latent (img_orig) & new sampled latent (img) - if mask is not None: - assert x0 is not None - if clean_cond: - img_orig = x0 - else: - img_orig = self.model.q_sample(x0, ts) - img = img_orig * mask + (1. - mask) * img + # Use mask to blend noised original latent (img_orig) & new sampled latent (img) + if mask is not None: + assert x0 is not None + if clean_cond: + img_orig = x0 + else: + img_orig = self.model.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img - outs = self.p_sample_ddim( - img, - action, - state, - cond, - ts, - index=index, - use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, - temperature=temperature, - noise_dropout=noise_dropout, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - mask=mask, - x0=x0, - fs=fs, - guidance_rescale=guidance_rescale, - **kwargs) + outs = self.p_sample_ddim( + img, + action, + state, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + mask=mask, + x0=x0, + fs=fs, + guidance_rescale=guidance_rescale, + **kwargs) - img, pred_x0, model_output_action, model_output_state = outs + img, pred_x0, model_output_action, model_output_state = outs - action = dp_ddim_scheduler_action.step( - model_output_action, - step, - action, - generator=None, - ).prev_sample - state = dp_ddim_scheduler_state.step( - model_output_state, - step, - state, - generator=None, - ).prev_sample + action = dp_ddim_scheduler_action.step( + model_output_action, + step, + action, + generator=None, + ).prev_sample + state = dp_ddim_scheduler_state.step( + model_output_state, + step, + state, + generator=None, + ).prev_sample - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) - if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) - intermediates['x_inter_action'].append(action) - intermediates['x_inter_state'].append(state) + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + intermediates['x_inter_action'].append(action) + intermediates['x_inter_state'].append(state) + finally: + disable_cross_attn_kv_cache(self.model) return img, action, state, intermediates diff --git a/src/unifolm_wma/modules/attention.py b/src/unifolm_wma/modules/attention.py index 0a4a703..cce300c 100644 --- a/src/unifolm_wma/modules/attention.py +++ b/src/unifolm_wma/modules/attention.py @@ -98,6 +98,9 @@ class CrossAttention(nn.Module): self.text_context_len = text_context_len self.agent_state_context_len = agent_state_context_len self.agent_action_context_len = agent_action_context_len + self._kv_cache = {} + self._kv_cache_enabled = False + self.cross_attention_scale_learnable = cross_attention_scale_learnable if self.image_cross_attention: self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False) @@ -236,17 +239,42 @@ class CrossAttention(nn.Module): k_ip, v_ip, out_ip = None, None, None k_as, v_as, out_as = None, None, None k_aa, v_aa, out_aa = None, None, None + attn_mask_aa = None + h = self.heads q = self.to_q(x) context = default(context, x) - if self.image_cross_attention and not spatial_self_attn: + b, _, _ = q.shape + q = q.unsqueeze(3).reshape(b, q.shape[1], h, self.dim_head).permute(0, 2, 1, 3).reshape(b * h, q.shape[1], self.dim_head).contiguous() + + def _reshape_kv(t): + return t.unsqueeze(3).reshape(b, t.shape[1], h, self.dim_head).permute(0, 2, 1, 3).reshape(b * h, t.shape[1], self.dim_head).contiguous() + + use_cache = self._kv_cache_enabled and not spatial_self_attn + cache_hit = use_cache and len(self._kv_cache) > 0 + + if cache_hit: + k = self._kv_cache['k'] + v = self._kv_cache['v'] + k_ip = self._kv_cache.get('k_ip') + v_ip = self._kv_cache.get('v_ip') + k_as = self._kv_cache.get('k_as') + v_as = self._kv_cache.get('v_as') + k_aa = self._kv_cache.get('k_aa') + v_aa = self._kv_cache.get('v_aa') + attn_mask_aa = self._kv_cache.get('attn_mask_aa') + elif self.image_cross_attention and not spatial_self_attn: if context.shape[1] == self.text_context_len + self.video_length: context_ins, context_image = context[:, :self.text_context_len, :], context[:,self.text_context_len:, :] k = self.to_k(context) v = self.to_v(context) k_ip = self.to_k_ip(context_image) v_ip = self.to_v_ip(context_image) + k, v = map(_reshape_kv, (k, v)) + k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip)) + if use_cache: + self._kv_cache = {'k': k, 'v': v, 'k_ip': k_ip, 'v_ip': v_ip} elif context.shape[1] == self.agent_state_context_len + self.text_context_len + self.video_length: context_agent_state = context[:, :self.agent_state_context_len, :] context_ins = context[:, self.agent_state_context_len:self.agent_state_context_len+self.text_context_len, :] @@ -257,6 +285,11 @@ class CrossAttention(nn.Module): v_ip = self.to_v_ip(context_image) k_as = self.to_k_as(context_agent_state) v_as = self.to_v_as(context_agent_state) + k, v = map(_reshape_kv, (k, v)) + k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip)) + k_as, v_as = map(_reshape_kv, (k_as, v_as)) + if use_cache: + self._kv_cache = {'k': k, 'v': v, 'k_ip': k_ip, 'v_ip': v_ip, 'k_as': k_as, 'v_as': v_as} else: context_agent_state = context[:, :self.agent_state_context_len, :] context_agent_action = context[:, self.agent_state_context_len:self.agent_state_context_len+self.agent_action_context_len, :] @@ -272,99 +305,78 @@ class CrossAttention(nn.Module): k_aa = self.to_k_aa(context_agent_action) v_aa = self.to_v_aa(context_agent_action) - attn_mask_aa = self._get_attn_mask_aa(x.shape[0], - q.shape[1], - k_aa.shape[1], - block_size=16, - device=k_aa.device) + k, v = map(_reshape_kv, (k, v)) + k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip)) + k_as, v_as = map(_reshape_kv, (k_as, v_as)) + k_aa, v_aa = map(_reshape_kv, (k_aa, v_aa)) + + attn_mask_aa_raw = self._get_attn_mask_aa(x.shape[0], + q.shape[1], + k_aa.shape[1], + block_size=16, + device=k_aa.device) + attn_mask_aa = attn_mask_aa_raw.unsqueeze(1).repeat(1, h, 1, 1).reshape( + b * h, attn_mask_aa_raw.shape[1], attn_mask_aa_raw.shape[2]).to(q.dtype) + + if use_cache: + self._kv_cache = { + 'k': k, 'v': v, 'k_ip': k_ip, 'v_ip': v_ip, + 'k_as': k_as, 'v_as': v_as, 'k_aa': k_aa, 'v_aa': v_aa, + 'attn_mask_aa': attn_mask_aa, + } else: if not spatial_self_attn: assert 1 > 2, ">>> ERROR: you should never go into here ..." context = context[:, :self.text_context_len, :] k = self.to_k(context) v = self.to_v(context) - - b, _, _ = q.shape - q = q.unsqueeze(3).reshape(b, q.shape[1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(b * self.heads, q.shape[1], self.dim_head).contiguous() + k, v = map(_reshape_kv, (k, v)) + if use_cache: + self._kv_cache = {'k': k, 'v': v} if k is not None: - k, v = map( - lambda t: t.unsqueeze(3).reshape(b, t.shape[ - 1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape( - b * self.heads, t.shape[1], self.dim_head).contiguous(), - (k, v), - ) out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None) out = (out.unsqueeze(0).reshape( - b, self.heads, out.shape[1], + b, h, out.shape[1], self.dim_head).permute(0, 2, 1, 3).reshape(b, out.shape[1], - self.heads * self.dim_head)) + h * self.dim_head)) if k_ip is not None: - # For image cross-attention - k_ip, v_ip = map( - lambda t: t.unsqueeze(3).reshape(b, t.shape[ - 1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape( - b * self.heads, t.shape[1], self.dim_head).contiguous( - ), - (k_ip, v_ip), - ) out_ip = xformers.ops.memory_efficient_attention(q, k_ip, v_ip, attn_bias=None, op=None) out_ip = (out_ip.unsqueeze(0).reshape( - b, self.heads, out_ip.shape[1], + b, h, out_ip.shape[1], self.dim_head).permute(0, 2, 1, 3).reshape(b, out_ip.shape[1], - self.heads * self.dim_head)) + h * self.dim_head)) if k_as is not None: - # For agent state cross-attention - k_as, v_as = map( - lambda t: t.unsqueeze(3).reshape(b, t.shape[ - 1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape( - b * self.heads, t.shape[1], self.dim_head).contiguous( - ), - (k_as, v_as), - ) out_as = xformers.ops.memory_efficient_attention(q, k_as, v_as, attn_bias=None, op=None) out_as = (out_as.unsqueeze(0).reshape( - b, self.heads, out_as.shape[1], + b, h, out_as.shape[1], self.dim_head).permute(0, 2, 1, 3).reshape(b, out_as.shape[1], - self.heads * self.dim_head)) + h * self.dim_head)) + if k_aa is not None: - # For agent action cross-attention - k_aa, v_aa = map( - lambda t: t.unsqueeze(3).reshape(b, t.shape[ - 1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape( - b * self.heads, t.shape[1], self.dim_head).contiguous( - ), - (k_aa, v_aa), - ) - - attn_mask_aa = attn_mask_aa.unsqueeze(1).repeat(1,self.heads,1,1).reshape( - b * self.heads, attn_mask_aa.shape[1], attn_mask_aa.shape[2]) - attn_mask_aa = attn_mask_aa.to(q.dtype) - out_aa = xformers.ops.memory_efficient_attention( q, k_aa, v_aa, attn_bias=attn_mask_aa, op=None) - out_aa = (out_aa.unsqueeze(0).reshape( - b, self.heads, out_aa.shape[1], + b, h, out_aa.shape[1], self.dim_head).permute(0, 2, 1, 3).reshape(b, out_aa.shape[1], - self.heads * self.dim_head)) + h * self.dim_head)) if exists(mask): raise NotImplementedError @@ -410,6 +422,20 @@ class CrossAttention(nn.Module): return attn_mask +def enable_cross_attn_kv_cache(module): + for m in module.modules(): + if isinstance(m, CrossAttention): + m._kv_cache_enabled = True + m._kv_cache = {} + + +def disable_cross_attn_kv_cache(module): + for m in module.modules(): + if isinstance(m, CrossAttention): + m._kv_cache_enabled = False + m._kv_cache = {} + + class BasicTransformerBlock(nn.Module): def __init__(self, diff --git a/unitree_z1_dual_arm_cleanup_pencils/case1/psnr_result.json b/unitree_z1_dual_arm_cleanup_pencils/case1/psnr_result.json new file mode 100644 index 0000000..033b2da --- /dev/null +++ b/unitree_z1_dual_arm_cleanup_pencils/case1/psnr_result.json @@ -0,0 +1,5 @@ +{ + "gt_video": "unitree_z1_dual_arm_cleanup_pencils/case1/unitree_z1_dual_arm_cleanup_pencils_case1.mp4", + "pred_video": "unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/0_full_fs4.mp4", + "psnr": 47.911564449209735 +} \ No newline at end of file diff --git a/unitree_z1_dual_arm_stackbox_v2/case1/output.log b/unitree_z1_dual_arm_stackbox_v2/case1/output.log index af627c1..72b254d 100644 --- a/unitree_z1_dual_arm_stackbox_v2/case1/output.log +++ b/unitree_z1_dual_arm_stackbox_v2/case1/output.log @@ -1,10 +1,10 @@ -2026-02-10 17:03:42.057881: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. -2026-02-10 17:03:42.107520: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered -2026-02-10 17:03:42.107564: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered -2026-02-10 17:03:42.108900: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered -2026-02-10 17:03:42.116404: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. +2026-02-10 17:25:35.484333: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. +2026-02-10 17:25:35.533963: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered +2026-02-10 17:25:35.534009: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered +2026-02-10 17:25:35.535311: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered +2026-02-10 17:25:35.542814: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. -2026-02-10 17:03:43.044539: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT +2026-02-10 17:25:36.471650: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT Global seed set to 123 INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08 @@ -92,7 +92,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin DEBUG:PIL.Image:Importing XbmImagePlugin DEBUG:PIL.Image:Importing XpmImagePlugin DEBUG:PIL.Image:Importing XVThumbImagePlugin - 9%|▉ | 1/11 [00:37<06:15, 37.55s/it] 18%|█▊ | 2/11 [01:15<05:39, 37.71s/it] 27%|██▋ | 3/11 [01:53<05:03, 37.98s/it] 36%|███▋ | 4/11 [02:32<04:26, 38.13s/it] 45%|████▌ | 5/11 [03:10<03:48, 38.14s/it] 55%|█████▍ | 6/11 [03:48<03:10, 38.07s/it] 64%|██████▎ | 7/11 [04:26<02:32, 38.02s/it] 73%|███████▎ | 8/11 [05:04<01:54, 38.01s/it] 82%|████████▏ | 9/11 [05:41<01:15, 37.99s/it] 91%|█████████ | 10/11 [06:19<00:37, 37.99s/it] 100%|██████████| 11/11 [06:57<00:00, 38.00s/it] 100%|██████████| 11/11 [06:57<00:00, 38.00s/it] + 9%|▉ | 1/11 [00:36<06:07, 36.77s/it] 18%|█▊ | 2/11 [01:13<05:32, 36.92s/it] 27%|██▋ | 3/11 [01:51<04:57, 37.20s/it] 36%|███▋ | 4/11 [02:28<04:21, 37.35s/it] 45%|████▌ | 5/11 [03:06<03:44, 37.39s/it] 55%|█████▍ | 6/11 [03:43<03:06, 37.31s/it] 64%|██████▎ | 7/11 [04:20<02:29, 37.26s/it] 73%|███████▎ | 8/11 [04:57<01:51, 37.24s/it] 82%|████████▏ | 9/11 [05:35<01:14, 37.22s/it] 91%|█████████ | 10/11 [06:12<00:37, 37.23s/it] 100%|██████████| 11/11 [06:49<00:00, 37.23s/it] 100%|██████████| 11/11 [06:49<00:00, 37.23s/it] >>>>>>>>>>>>>>>>>>>>>>>> >>> Step 1: generating actions ... >>> Step 1: interacting with world model ... @@ -125,6 +125,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin >>> Step 10: interacting with world model ... >>>>>>>>>>>>>>>>>>>>>>>> -real 8m36.548s -user 9m22.484s -sys 1m21.506s +real 8m29.264s +user 9m16.382s +sys 1m14.959s diff --git a/unitree_z1_dual_arm_stackbox_v2/case1/psnr_result.json b/unitree_z1_dual_arm_stackbox_v2/case1/psnr_result.json new file mode 100644 index 0000000..af6ad95 --- /dev/null +++ b/unitree_z1_dual_arm_stackbox_v2/case1/psnr_result.json @@ -0,0 +1,5 @@ +{ + "gt_video": "/home/qhy/unifolm-world-model-action/unitree_z1_dual_arm_stackbox_v2/case1/unitree_z1_dual_arm_stackbox_v2_case1.mp4", + "pred_video": "/home/qhy/unifolm-world-model-action/unitree_z1_dual_arm_stackbox_v2/case1/output/inference/5_full_fs4.mp4", + "psnr": 25.12008483689618 +} \ No newline at end of file