实现了Context 预计算和缓存功能,提升了采样效率。 psnr不下降

This commit is contained in:
2026-02-09 17:42:47 +00:00
parent f192c8aca9
commit 6dca3696d8
4 changed files with 119 additions and 69 deletions

View File

@@ -501,6 +501,10 @@ class ConditionalUnet1D(nn.Module):
self.last_frame_only = last_frame_only
self.horizon = horizon
# Context precomputation cache
self._global_cond_cache_enabled = False
self._global_cond_cache = {}
def forward(self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
@@ -530,6 +534,10 @@ class ConditionalUnet1D(nn.Module):
B, T, D = sample.shape
if self.use_linear_act_proj:
sample = self.proj_in_action(sample.unsqueeze(-1))
_gc_key = (cond['image'].data_ptr(), cond['agent_pos'].data_ptr())
if self._global_cond_cache_enabled and _gc_key in self._global_cond_cache:
global_cond = self._global_cond_cache[_gc_key]
else:
global_cond = self.obs_encoder(cond)
global_cond = rearrange(global_cond,
'(b t) d -> b 1 (t d)',
@@ -538,6 +546,8 @@ class ConditionalUnet1D(nn.Module):
global_cond = repeat(global_cond,
'b c d -> b (repeat c) d',
repeat=T)
if self._global_cond_cache_enabled:
self._global_cond_cache[_gc_key] = global_cond
else:
sample = einops.rearrange(sample, 'b h t -> b t h')
sample = self.proj_in_horizon(sample)

View File

@@ -7,6 +7,7 @@ from unifolm_wma.utils.common import noise_like
from unifolm_wma.utils.common import extract_into_tensor
from tqdm import tqdm
from unifolm_wma.modules.attention import enable_cross_attn_kv_cache, disable_cross_attn_kv_cache
from unifolm_wma.modules.networks.wma_model import enable_ctx_cache, disable_ctx_cache
class DDIMSampler(object):
@@ -245,6 +246,7 @@ class DDIMSampler(object):
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
ts = torch.empty((b, ), device=device, dtype=torch.long)
enable_cross_attn_kv_cache(self.model)
enable_ctx_cache(self.model)
try:
for i, step in enumerate(iterator):
index = total_steps - i - 1
@@ -305,6 +307,7 @@ class DDIMSampler(object):
intermediates['x_inter_state'].append(state)
finally:
disable_cross_attn_kv_cache(self.model)
disable_ctx_cache(self.model)
return img, action, state, intermediates

View File

@@ -685,6 +685,10 @@ class WMAModel(nn.Module):
self.action_token_projector = instantiate_from_config(
stem_process_config)
# Context precomputation cache
self._ctx_cache_enabled = False
self._ctx_cache = {}
def forward(self,
x: Tensor,
x_action: Tensor,
@@ -720,6 +724,10 @@ class WMAModel(nn.Module):
repeat_only=False).type(x.dtype)
emb = self.time_embed(t_emb)
_ctx_key = context.data_ptr()
if self._ctx_cache_enabled and _ctx_key in self._ctx_cache:
context = self._ctx_cache[_ctx_key]
else:
bt, l_context, _ = context.shape
if self.base_model_gen_only:
assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE
@@ -772,6 +780,8 @@ class WMAModel(nn.Module):
context_img
],
dim=1)
if self._ctx_cache_enabled:
self._ctx_cache[_ctx_key] = context
emb = emb.repeat_interleave(repeats=t, dim=0)
@@ -846,3 +856,30 @@ class WMAModel(nn.Module):
s_y = torch.zeros_like(x_state)
return y, a_y, s_y
def enable_ctx_cache(model):
"""Enable context precomputation cache on WMAModel and its action/state UNets."""
for m in model.modules():
if isinstance(m, WMAModel):
m._ctx_cache_enabled = True
m._ctx_cache = {}
# conditional_unet1d cache
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
for m in model.modules():
if isinstance(m, ConditionalUnet1D):
m._global_cond_cache_enabled = True
m._global_cond_cache = {}
def disable_ctx_cache(model):
"""Disable and clear context precomputation cache."""
for m in model.modules():
if isinstance(m, WMAModel):
m._ctx_cache_enabled = False
m._ctx_cache = {}
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
for m in model.modules():
if isinstance(m, ConditionalUnet1D):
m._global_cond_cache_enabled = False
m._global_cond_cache = {}

View File

@@ -1,14 +1,14 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-09 16:53:59.556813: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-09 16:53:59.559892: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-09 16:53:59.591414: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-09 16:53:59.591446: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-09 16:53:59.593281: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-09 16:53:59.601486: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-09 16:53:59.601838: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
2026-02-09 17:32:41.850068: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-09 17:32:41.853132: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-09 17:32:41.886058: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-09 17:32:41.886103: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-09 17:32:41.887979: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-09 17:32:41.896994: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-09 17:32:41.897283: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-09 16:54:00.228108: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
2026-02-09 17:32:42.611394: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
@@ -116,7 +116,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
12%|█▎ | 1/8 [01:08<08:00, 68.62s/it]
25%|██▌ | 2/8 [02:13<06:38, 66.41s/it]
@@ -140,6 +140,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>