实现了Context 预计算和缓存功能,提升了采样效率。 psnr不下降
This commit is contained in:
@@ -501,6 +501,10 @@ class ConditionalUnet1D(nn.Module):
|
||||
self.last_frame_only = last_frame_only
|
||||
self.horizon = horizon
|
||||
|
||||
# Context precomputation cache
|
||||
self._global_cond_cache_enabled = False
|
||||
self._global_cond_cache = {}
|
||||
|
||||
def forward(self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
@@ -530,14 +534,20 @@ class ConditionalUnet1D(nn.Module):
|
||||
B, T, D = sample.shape
|
||||
if self.use_linear_act_proj:
|
||||
sample = self.proj_in_action(sample.unsqueeze(-1))
|
||||
global_cond = self.obs_encoder(cond)
|
||||
global_cond = rearrange(global_cond,
|
||||
'(b t) d -> b 1 (t d)',
|
||||
b=B,
|
||||
t=self.n_obs_steps)
|
||||
global_cond = repeat(global_cond,
|
||||
'b c d -> b (repeat c) d',
|
||||
repeat=T)
|
||||
_gc_key = (cond['image'].data_ptr(), cond['agent_pos'].data_ptr())
|
||||
if self._global_cond_cache_enabled and _gc_key in self._global_cond_cache:
|
||||
global_cond = self._global_cond_cache[_gc_key]
|
||||
else:
|
||||
global_cond = self.obs_encoder(cond)
|
||||
global_cond = rearrange(global_cond,
|
||||
'(b t) d -> b 1 (t d)',
|
||||
b=B,
|
||||
t=self.n_obs_steps)
|
||||
global_cond = repeat(global_cond,
|
||||
'b c d -> b (repeat c) d',
|
||||
repeat=T)
|
||||
if self._global_cond_cache_enabled:
|
||||
self._global_cond_cache[_gc_key] = global_cond
|
||||
else:
|
||||
sample = einops.rearrange(sample, 'b h t -> b t h')
|
||||
sample = self.proj_in_horizon(sample)
|
||||
|
||||
Reference in New Issue
Block a user