实现 fs_embed 缓存，收益不明显，精度不降低
This commit is contained in:
@@ -688,6 +688,8 @@ class WMAModel(nn.Module):
|
||||
# Context precomputation cache
|
||||
self._ctx_cache_enabled = False
|
||||
self._ctx_cache = {}
|
||||
# fs_embed cache
|
||||
self._fs_embed_cache = None
|
||||
|
||||
def forward(self,
|
||||
x: Tensor,
|
||||
@@ -789,16 +791,20 @@ class WMAModel(nn.Module):
|
||||
|
||||
# Combine emb
|
||||
if self.fs_condition:
|
||||
if fs is None:
|
||||
fs = torch.tensor([self.default_fs] * b,
|
||||
dtype=torch.long,
|
||||
device=x.device)
|
||||
fs_emb = timestep_embedding(fs,
|
||||
self.model_channels,
|
||||
repeat_only=False).type(x.dtype)
|
||||
|
||||
fs_embed = self.fps_embedding(fs_emb)
|
||||
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
|
||||
if self._ctx_cache_enabled and self._fs_embed_cache is not None:
|
||||
fs_embed = self._fs_embed_cache
|
||||
else:
|
||||
if fs is None:
|
||||
fs = torch.tensor([self.default_fs] * b,
|
||||
dtype=torch.long,
|
||||
device=x.device)
|
||||
fs_emb = timestep_embedding(fs,
|
||||
self.model_channels,
|
||||
repeat_only=False).type(x.dtype)
|
||||
fs_embed = self.fps_embedding(fs_emb)
|
||||
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
|
||||
if self._ctx_cache_enabled:
|
||||
self._fs_embed_cache = fs_embed
|
||||
emb = emb + fs_embed
|
||||
|
||||
h = x.type(self.dtype)
|
||||
@@ -864,6 +870,7 @@ def enable_ctx_cache(model):
|
||||
if isinstance(m, WMAModel):
|
||||
m._ctx_cache_enabled = True
|
||||
m._ctx_cache = {}
|
||||
m._fs_embed_cache = None
|
||||
# conditional_unet1d cache
|
||||
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
|
||||
for m in model.modules():
|
||||
@@ -878,6 +885,7 @@ def disable_ctx_cache(model):
|
||||
if isinstance(m, WMAModel):
|
||||
m._ctx_cache_enabled = False
|
||||
m._ctx_cache = {}
|
||||
m._fs_embed_cache = None
|
||||
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
|
||||
for m in model.modules():
|
||||
if isinstance(m, ConditionalUnet1D):
|
||||
|
||||
Reference in New Issue
Block a user