实现 fs_embed 缓存,收益不明显,精度不降低

This commit is contained in:
2026-02-09 18:49:44 +00:00
parent 0b3b0e534a
commit 125b85ce68
2 changed files with 30 additions and 22 deletions

View File

@@ -688,6 +688,8 @@ class WMAModel(nn.Module):
# Context precomputation cache
self._ctx_cache_enabled = False
self._ctx_cache = {}
# fs_embed cache
self._fs_embed_cache = None
def forward(self,
x: Tensor,
@@ -789,16 +791,20 @@ class WMAModel(nn.Module):
# Combine emb
if self.fs_condition:
if fs is None:
fs = torch.tensor([self.default_fs] * b,
dtype=torch.long,
device=x.device)
fs_emb = timestep_embedding(fs,
self.model_channels,
repeat_only=False).type(x.dtype)
fs_embed = self.fps_embedding(fs_emb)
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
if self._ctx_cache_enabled and self._fs_embed_cache is not None:
fs_embed = self._fs_embed_cache
else:
if fs is None:
fs = torch.tensor([self.default_fs] * b,
dtype=torch.long,
device=x.device)
fs_emb = timestep_embedding(fs,
self.model_channels,
repeat_only=False).type(x.dtype)
fs_embed = self.fps_embedding(fs_emb)
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
if self._ctx_cache_enabled:
self._fs_embed_cache = fs_embed
emb = emb + fs_embed
h = x.type(self.dtype)
@@ -864,6 +870,7 @@ def enable_ctx_cache(model):
if isinstance(m, WMAModel):
m._ctx_cache_enabled = True
m._ctx_cache = {}
m._fs_embed_cache = None
# conditional_unet1d cache
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
for m in model.modules():
@@ -878,6 +885,7 @@ def disable_ctx_cache(model):
if isinstance(m, WMAModel):
m._ctx_cache_enabled = False
m._ctx_cache = {}
m._fs_embed_cache = None
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
for m in model.modules():
if isinstance(m, ConditionalUnet1D):