Implemented context precomputation and caching to speed up sampling; PSNR does not degrade.

This commit is contained in:
qhy
2026-02-10 17:47:46 +08:00
parent 223a50f9e0
commit 9347a4ebe5
4 changed files with 117 additions and 67 deletions

View File

@@ -501,6 +501,10 @@ class ConditionalUnet1D(nn.Module):
self.last_frame_only = last_frame_only self.last_frame_only = last_frame_only
self.horizon = horizon self.horizon = horizon
# Context precomputation cache
self._global_cond_cache_enabled = False
self._global_cond_cache = {}
def forward(self, def forward(self,
sample: torch.Tensor, sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int], timestep: Union[torch.Tensor, float, int],
@@ -530,14 +534,20 @@ class ConditionalUnet1D(nn.Module):
B, T, D = sample.shape B, T, D = sample.shape
if self.use_linear_act_proj: if self.use_linear_act_proj:
sample = self.proj_in_action(sample.unsqueeze(-1)) sample = self.proj_in_action(sample.unsqueeze(-1))
global_cond = self.obs_encoder(cond) _gc_key = (cond['image'].data_ptr(), cond['agent_pos'].data_ptr())
global_cond = rearrange(global_cond, if self._global_cond_cache_enabled and _gc_key in self._global_cond_cache:
'(b t) d -> b 1 (t d)', global_cond = self._global_cond_cache[_gc_key]
b=B, else:
t=self.n_obs_steps) global_cond = self.obs_encoder(cond)
global_cond = repeat(global_cond, global_cond = rearrange(global_cond,
'b c d -> b (repeat c) d', '(b t) d -> b 1 (t d)',
repeat=T) b=B,
t=self.n_obs_steps)
global_cond = repeat(global_cond,
'b c d -> b (repeat c) d',
repeat=T)
if self._global_cond_cache_enabled:
self._global_cond_cache[_gc_key] = global_cond
else: else:
sample = einops.rearrange(sample, 'b h t -> b t h') sample = einops.rearrange(sample, 'b h t -> b t h')
sample = self.proj_in_horizon(sample) sample = self.proj_in_horizon(sample)

View File

@@ -7,6 +7,7 @@ from unifolm_wma.utils.common import noise_like
from unifolm_wma.utils.common import extract_into_tensor from unifolm_wma.utils.common import extract_into_tensor
from tqdm import tqdm from tqdm import tqdm
from unifolm_wma.modules.attention import enable_cross_attn_kv_cache, disable_cross_attn_kv_cache from unifolm_wma.modules.attention import enable_cross_attn_kv_cache, disable_cross_attn_kv_cache
from unifolm_wma.modules.networks.wma_model import enable_ctx_cache, disable_ctx_cache
class DDIMSampler(object): class DDIMSampler(object):
@@ -245,6 +246,7 @@ class DDIMSampler(object):
dp_ddim_scheduler_state.set_timesteps(len(timesteps)) dp_ddim_scheduler_state.set_timesteps(len(timesteps))
ts = torch.empty((b, ), device=device, dtype=torch.long) ts = torch.empty((b, ), device=device, dtype=torch.long)
enable_cross_attn_kv_cache(self.model) enable_cross_attn_kv_cache(self.model)
enable_ctx_cache(self.model)
try: try:
for i, step in enumerate(iterator): for i, step in enumerate(iterator):
index = total_steps - i - 1 index = total_steps - i - 1
@@ -305,6 +307,7 @@ class DDIMSampler(object):
intermediates['x_inter_state'].append(state) intermediates['x_inter_state'].append(state)
finally: finally:
disable_cross_attn_kv_cache(self.model) disable_cross_attn_kv_cache(self.model)
disable_ctx_cache(self.model)
return img, action, state, intermediates return img, action, state, intermediates

View File

@@ -685,6 +685,10 @@ class WMAModel(nn.Module):
self.action_token_projector = instantiate_from_config( self.action_token_projector = instantiate_from_config(
stem_process_config) stem_process_config)
# Context precomputation cache
self._ctx_cache_enabled = False
self._ctx_cache = {}
def forward(self, def forward(self,
x: Tensor, x: Tensor,
x_action: Tensor, x_action: Tensor,
@@ -720,58 +724,64 @@ class WMAModel(nn.Module):
repeat_only=False).type(x.dtype) repeat_only=False).type(x.dtype)
emb = self.time_embed(t_emb) emb = self.time_embed(t_emb)
bt, l_context, _ = context.shape _ctx_key = context.data_ptr()
if self.base_model_gen_only: if self._ctx_cache_enabled and _ctx_key in self._ctx_cache:
assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE context = self._ctx_cache[_ctx_key]
else: else:
if l_context == self.n_obs_steps + 77 + t * 16: bt, l_context, _ = context.shape
context_agent_state = context[:, :self.n_obs_steps] if self.base_model_gen_only:
context_text = context[:, self.n_obs_steps:self.n_obs_steps + assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE
77, :] else:
context_img = context[:, self.n_obs_steps + 77:, :] if l_context == self.n_obs_steps + 77 + t * 16:
context_agent_state = context_agent_state.repeat_interleave( context_agent_state = context[:, :self.n_obs_steps]
repeats=t, dim=0) context_text = context[:, self.n_obs_steps:self.n_obs_steps +
context_text = context_text.repeat_interleave(repeats=t, dim=0) 77, :]
context_img = rearrange(context_img, context_img = context[:, self.n_obs_steps + 77:, :]
'b (t l) c -> (b t) l c', context_agent_state = context_agent_state.repeat_interleave(
t=t) repeats=t, dim=0)
context = torch.cat( context_text = context_text.repeat_interleave(repeats=t, dim=0)
[context_agent_state, context_text, context_img], dim=1) context_img = rearrange(context_img,
elif l_context == self.n_obs_steps + 16 + 77 + t * 16: 'b (t l) c -> (b t) l c',
context_agent_state = context[:, :self.n_obs_steps] t=t)
context_agent_action = context[:, self. context = torch.cat(
n_obs_steps:self.n_obs_steps + [context_agent_state, context_text, context_img], dim=1)
16, :] elif l_context == self.n_obs_steps + 16 + 77 + t * 16:
context_agent_action = rearrange( context_agent_state = context[:, :self.n_obs_steps]
context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d') context_agent_action = context[:, self.
context_agent_action = self.action_token_projector( n_obs_steps:self.n_obs_steps +
context_agent_action) 16, :]
context_agent_action = rearrange(context_agent_action, context_agent_action = rearrange(
'(b o) l d -> b o l d', context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d')
o=t) context_agent_action = self.action_token_projector(
context_agent_action = rearrange(context_agent_action, context_agent_action)
'b o (t l) d -> b o t l d', context_agent_action = rearrange(context_agent_action,
t=t) '(b o) l d -> b o l d',
context_agent_action = context_agent_action.permute( o=t)
0, 2, 1, 3, 4) context_agent_action = rearrange(context_agent_action,
context_agent_action = rearrange(context_agent_action, 'b o (t l) d -> b o t l d',
'b t o l d -> (b t) (o l) d') t=t)
context_agent_action = context_agent_action.permute(
0, 2, 1, 3, 4)
context_agent_action = rearrange(context_agent_action,
'b t o l d -> (b t) (o l) d')
context_text = context[:, self.n_obs_steps + context_text = context[:, self.n_obs_steps +
16:self.n_obs_steps + 16 + 77, :] 16:self.n_obs_steps + 16 + 77, :]
context_text = context_text.repeat_interleave(repeats=t, dim=0) context_text = context_text.repeat_interleave(repeats=t, dim=0)
context_img = context[:, self.n_obs_steps + 16 + 77:, :] context_img = context[:, self.n_obs_steps + 16 + 77:, :]
context_img = rearrange(context_img, context_img = rearrange(context_img,
'b (t l) c -> (b t) l c', 'b (t l) c -> (b t) l c',
t=t) t=t)
context_agent_state = context_agent_state.repeat_interleave( context_agent_state = context_agent_state.repeat_interleave(
repeats=t, dim=0) repeats=t, dim=0)
context = torch.cat([ context = torch.cat([
context_agent_state, context_agent_action, context_text, context_agent_state, context_agent_action, context_text,
context_img context_img
], ],
dim=1) dim=1)
if self._ctx_cache_enabled:
self._ctx_cache[_ctx_key] = context
emb = emb.repeat_interleave(repeats=t, dim=0) emb = emb.repeat_interleave(repeats=t, dim=0)
@@ -846,3 +856,30 @@ class WMAModel(nn.Module):
s_y = torch.zeros_like(x_state) s_y = torch.zeros_like(x_state)
return y, a_y, s_y return y, a_y, s_y
def enable_ctx_cache(model):
"""Enable context precomputation cache on WMAModel and its action/state UNets."""
for m in model.modules():
if isinstance(m, WMAModel):
m._ctx_cache_enabled = True
m._ctx_cache = {}
# conditional_unet1d cache
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
for m in model.modules():
if isinstance(m, ConditionalUnet1D):
m._global_cond_cache_enabled = True
m._global_cond_cache = {}
def disable_ctx_cache(model):
"""Disable and clear context precomputation cache."""
for m in model.modules():
if isinstance(m, WMAModel):
m._ctx_cache_enabled = False
m._ctx_cache = {}
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
for m in model.modules():
if isinstance(m, ConditionalUnet1D):
m._global_cond_cache_enabled = False
m._global_cond_cache = {}

View File

@@ -1,10 +1,10 @@
2026-02-10 17:25:35.484333: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. 2026-02-10 17:39:22.590654: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-10 17:25:35.533963: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2026-02-10 17:39:22.640645: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-10 17:25:35.534009: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2026-02-10 17:39:22.640689: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-10 17:25:35.535311: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered 2026-02-10 17:39:22.642010: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-10 17:25:35.542814: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. 2026-02-10 17:39:22.649530: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-10 17:25:36.471650: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT 2026-02-10 17:39:23.575804: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Global seed set to 123 Global seed set to 123
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08 INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
@@ -92,7 +92,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin DEBUG:PIL.Image:Importing XVThumbImagePlugin
9%|▉ | 1/11 [00:35<05:55, 35.52s/it] 9%|▉ | 1/11 [00:35<05:55, 35.52s/it]
18%|█▊ | 2/11 [01:11<05:21, 35.73s/it] 18%|█▊ | 2/11 [01:11<05:21, 35.73s/it]
@@ -125,6 +125,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
>>> Step 6: generating actions ... >>> Step 6: generating actions ...
>>> Step 6: interacting with world model ... >>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>> >>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ... >>> Step 7: generating actions ...
>>> Step 7: interacting with world model ... >>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>> >>>>>>>>>>>>>>>>>>>>>>>>