复用 DDIMSampler + make_schedule微弱提升

This commit is contained in:
2026-02-09 18:26:39 +00:00
parent 6dca3696d8
commit 0b3b0e534a
3 changed files with 21 additions and 13 deletions

View File

@@ -1803,7 +1803,9 @@ class LatentDiffusion(DDPM):
"""
if ddim:
ddim_sampler = DDIMSampler(self)
if not hasattr(self, '_ddim_sampler') or self._ddim_sampler is None:
self._ddim_sampler = DDIMSampler(self)
ddim_sampler = self._ddim_sampler
shape = (self.channels, self.temporal_length, *self.image_size)
samples, actions, states, intermediates = ddim_sampler.sample(
ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)

View File

@@ -18,6 +18,7 @@ class DDIMSampler(object):
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.counter = 0
self._schedule_key = None # (ddim_num_steps, ddim_discretize, ddim_eta)
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
@@ -30,6 +31,11 @@ class DDIMSampler(object):
ddim_discretize="uniform",
ddim_eta=0.,
verbose=True):
key = (ddim_num_steps, ddim_discretize, ddim_eta)
if self._schedule_key == key:
return
self._schedule_key = key
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,