复用 DDIMSampler + make_schedule微弱提升
This commit is contained in:
@@ -18,6 +18,7 @@ class DDIMSampler(object):
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
self.counter = 0
|
||||
self._schedule_key = None # (ddim_num_steps, ddim_discretize, ddim_eta)
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
@@ -30,6 +31,11 @@ class DDIMSampler(object):
|
||||
ddim_discretize="uniform",
|
||||
ddim_eta=0.,
|
||||
verbose=True):
|
||||
key = (ddim_num_steps, ddim_discretize, ddim_eta)
|
||||
if self._schedule_key == key:
|
||||
return
|
||||
self._schedule_key = key
|
||||
|
||||
self.ddim_timesteps = make_ddim_timesteps(
|
||||
ddim_discr_method=ddim_discretize,
|
||||
num_ddim_timesteps=ddim_num_steps,
|
||||
|
||||
Reference in New Issue
Block a user