实现了Context 预计算和缓存功能,提升了采样效率。 psnr不下降

This commit is contained in:
2026-02-09 17:42:47 +00:00
parent f192c8aca9
commit 6dca3696d8
4 changed files with 119 additions and 69 deletions

View File

@@ -7,6 +7,7 @@ from unifolm_wma.utils.common import noise_like
from unifolm_wma.utils.common import extract_into_tensor
from tqdm import tqdm
from unifolm_wma.modules.attention import enable_cross_attn_kv_cache, disable_cross_attn_kv_cache
from unifolm_wma.modules.networks.wma_model import enable_ctx_cache, disable_ctx_cache
class DDIMSampler(object):
@@ -245,6 +246,7 @@ class DDIMSampler(object):
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
ts = torch.empty((b, ), device=device, dtype=torch.long)
enable_cross_attn_kv_cache(self.model)
enable_ctx_cache(self.model)
try:
for i, step in enumerate(iterator):
index = total_steps - i - 1
@@ -305,6 +307,7 @@ class DDIMSampler(object):
intermediates['x_inter_state'].append(state)
finally:
disable_cross_attn_kv_cache(self.model)
disable_ctx_cache(self.model)
return img, action, state, intermediates