实现了Context 预计算和缓存功能,提升了采样效率。 psnr不下降
This commit is contained in:
@@ -501,6 +501,10 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
self.last_frame_only = last_frame_only
|
self.last_frame_only = last_frame_only
|
||||||
self.horizon = horizon
|
self.horizon = horizon
|
||||||
|
|
||||||
|
# Context precomputation cache
|
||||||
|
self._global_cond_cache_enabled = False
|
||||||
|
self._global_cond_cache = {}
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
sample: torch.Tensor,
|
sample: torch.Tensor,
|
||||||
timestep: Union[torch.Tensor, float, int],
|
timestep: Union[torch.Tensor, float, int],
|
||||||
@@ -530,6 +534,10 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
B, T, D = sample.shape
|
B, T, D = sample.shape
|
||||||
if self.use_linear_act_proj:
|
if self.use_linear_act_proj:
|
||||||
sample = self.proj_in_action(sample.unsqueeze(-1))
|
sample = self.proj_in_action(sample.unsqueeze(-1))
|
||||||
|
_gc_key = (cond['image'].data_ptr(), cond['agent_pos'].data_ptr())
|
||||||
|
if self._global_cond_cache_enabled and _gc_key in self._global_cond_cache:
|
||||||
|
global_cond = self._global_cond_cache[_gc_key]
|
||||||
|
else:
|
||||||
global_cond = self.obs_encoder(cond)
|
global_cond = self.obs_encoder(cond)
|
||||||
global_cond = rearrange(global_cond,
|
global_cond = rearrange(global_cond,
|
||||||
'(b t) d -> b 1 (t d)',
|
'(b t) d -> b 1 (t d)',
|
||||||
@@ -538,6 +546,8 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
global_cond = repeat(global_cond,
|
global_cond = repeat(global_cond,
|
||||||
'b c d -> b (repeat c) d',
|
'b c d -> b (repeat c) d',
|
||||||
repeat=T)
|
repeat=T)
|
||||||
|
if self._global_cond_cache_enabled:
|
||||||
|
self._global_cond_cache[_gc_key] = global_cond
|
||||||
else:
|
else:
|
||||||
sample = einops.rearrange(sample, 'b h t -> b t h')
|
sample = einops.rearrange(sample, 'b h t -> b t h')
|
||||||
sample = self.proj_in_horizon(sample)
|
sample = self.proj_in_horizon(sample)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from unifolm_wma.utils.common import noise_like
|
|||||||
from unifolm_wma.utils.common import extract_into_tensor
|
from unifolm_wma.utils.common import extract_into_tensor
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from unifolm_wma.modules.attention import enable_cross_attn_kv_cache, disable_cross_attn_kv_cache
|
from unifolm_wma.modules.attention import enable_cross_attn_kv_cache, disable_cross_attn_kv_cache
|
||||||
|
from unifolm_wma.modules.networks.wma_model import enable_ctx_cache, disable_ctx_cache
|
||||||
|
|
||||||
|
|
||||||
class DDIMSampler(object):
|
class DDIMSampler(object):
|
||||||
@@ -245,6 +246,7 @@ class DDIMSampler(object):
|
|||||||
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
|
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
|
||||||
ts = torch.empty((b, ), device=device, dtype=torch.long)
|
ts = torch.empty((b, ), device=device, dtype=torch.long)
|
||||||
enable_cross_attn_kv_cache(self.model)
|
enable_cross_attn_kv_cache(self.model)
|
||||||
|
enable_ctx_cache(self.model)
|
||||||
try:
|
try:
|
||||||
for i, step in enumerate(iterator):
|
for i, step in enumerate(iterator):
|
||||||
index = total_steps - i - 1
|
index = total_steps - i - 1
|
||||||
@@ -305,6 +307,7 @@ class DDIMSampler(object):
|
|||||||
intermediates['x_inter_state'].append(state)
|
intermediates['x_inter_state'].append(state)
|
||||||
finally:
|
finally:
|
||||||
disable_cross_attn_kv_cache(self.model)
|
disable_cross_attn_kv_cache(self.model)
|
||||||
|
disable_ctx_cache(self.model)
|
||||||
|
|
||||||
return img, action, state, intermediates
|
return img, action, state, intermediates
|
||||||
|
|
||||||
|
|||||||
@@ -685,6 +685,10 @@ class WMAModel(nn.Module):
|
|||||||
self.action_token_projector = instantiate_from_config(
|
self.action_token_projector = instantiate_from_config(
|
||||||
stem_process_config)
|
stem_process_config)
|
||||||
|
|
||||||
|
# Context precomputation cache
|
||||||
|
self._ctx_cache_enabled = False
|
||||||
|
self._ctx_cache = {}
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
x_action: Tensor,
|
x_action: Tensor,
|
||||||
@@ -720,6 +724,10 @@ class WMAModel(nn.Module):
|
|||||||
repeat_only=False).type(x.dtype)
|
repeat_only=False).type(x.dtype)
|
||||||
emb = self.time_embed(t_emb)
|
emb = self.time_embed(t_emb)
|
||||||
|
|
||||||
|
_ctx_key = context.data_ptr()
|
||||||
|
if self._ctx_cache_enabled and _ctx_key in self._ctx_cache:
|
||||||
|
context = self._ctx_cache[_ctx_key]
|
||||||
|
else:
|
||||||
bt, l_context, _ = context.shape
|
bt, l_context, _ = context.shape
|
||||||
if self.base_model_gen_only:
|
if self.base_model_gen_only:
|
||||||
assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE
|
assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE
|
||||||
@@ -772,6 +780,8 @@ class WMAModel(nn.Module):
|
|||||||
context_img
|
context_img
|
||||||
],
|
],
|
||||||
dim=1)
|
dim=1)
|
||||||
|
if self._ctx_cache_enabled:
|
||||||
|
self._ctx_cache[_ctx_key] = context
|
||||||
|
|
||||||
emb = emb.repeat_interleave(repeats=t, dim=0)
|
emb = emb.repeat_interleave(repeats=t, dim=0)
|
||||||
|
|
||||||
@@ -846,3 +856,30 @@ class WMAModel(nn.Module):
|
|||||||
s_y = torch.zeros_like(x_state)
|
s_y = torch.zeros_like(x_state)
|
||||||
|
|
||||||
return y, a_y, s_y
|
return y, a_y, s_y
|
||||||
|
|
||||||
|
|
||||||
|
def enable_ctx_cache(model):
|
||||||
|
"""Enable context precomputation cache on WMAModel and its action/state UNets."""
|
||||||
|
for m in model.modules():
|
||||||
|
if isinstance(m, WMAModel):
|
||||||
|
m._ctx_cache_enabled = True
|
||||||
|
m._ctx_cache = {}
|
||||||
|
# conditional_unet1d cache
|
||||||
|
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
|
||||||
|
for m in model.modules():
|
||||||
|
if isinstance(m, ConditionalUnet1D):
|
||||||
|
m._global_cond_cache_enabled = True
|
||||||
|
m._global_cond_cache = {}
|
||||||
|
|
||||||
|
|
||||||
|
def disable_ctx_cache(model):
|
||||||
|
"""Disable and clear context precomputation cache."""
|
||||||
|
for m in model.modules():
|
||||||
|
if isinstance(m, WMAModel):
|
||||||
|
m._ctx_cache_enabled = False
|
||||||
|
m._ctx_cache = {}
|
||||||
|
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
|
||||||
|
for m in model.modules():
|
||||||
|
if isinstance(m, ConditionalUnet1D):
|
||||||
|
m._global_cond_cache_enabled = False
|
||||||
|
m._global_cond_cache = {}
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
2026-02-10 17:25:35.484333: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-10 17:39:22.590654: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
2026-02-10 17:25:35.533963: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-10 17:39:22.640645: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-10 17:25:35.534009: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
2026-02-10 17:39:22.640689: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-10 17:25:35.535311: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-10 17:39:22.642010: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-10 17:25:35.542814: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
2026-02-10 17:39:22.649530: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-10 17:25:36.471650: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
2026-02-10 17:39:23.575804: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
Global seed set to 123
|
Global seed set to 123
|
||||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
@@ -92,7 +92,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
|
|||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
9%|▉ | 1/11 [00:35<05:55, 35.52s/it]
|
9%|▉ | 1/11 [00:35<05:55, 35.52s/it]
|
||||||
18%|█▊ | 2/11 [01:11<05:21, 35.73s/it]
|
18%|█▊ | 2/11 [01:11<05:21, 35.73s/it]
|
||||||
@@ -125,6 +125,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 6: generating actions ...
|
>>> Step 6: generating actions ...
|
||||||
>>> Step 6: interacting with world model ...
|
>>> Step 6: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 7: generating actions ...
|
>>> Step 7: generating actions ...
|
||||||
>>> Step 7: interacting with world model ...
|
>>> Step 7: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
Reference in New Issue
Block a user