DDIM loop 内小张量分配优化,attention mask 缓存到 GPU,加速30s左右
This commit is contained in:
@@ -28,6 +28,11 @@ class DDIMSampler(object):
|
||||
ddim_discretize="uniform",
|
||||
ddim_eta=0.,
|
||||
verbose=True):
|
||||
device = self.model.betas.device
|
||||
cache_key = (ddim_num_steps, ddim_discretize, float(ddim_eta),
|
||||
str(device))
|
||||
if getattr(self, "_schedule_cache", None) == cache_key:
|
||||
return
|
||||
self.ddim_timesteps = make_ddim_timesteps(
|
||||
ddim_discr_method=ddim_discretize,
|
||||
num_ddim_timesteps=ddim_num_steps,
|
||||
@@ -67,16 +72,26 @@ class DDIMSampler(object):
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,
|
||||
verbose=verbose)
|
||||
ddim_sigmas = torch.as_tensor(ddim_sigmas,
|
||||
device=self.model.device,
|
||||
dtype=torch.float32)
|
||||
ddim_alphas = torch.as_tensor(ddim_alphas,
|
||||
device=self.model.device,
|
||||
dtype=torch.float32)
|
||||
ddim_alphas_prev = torch.as_tensor(ddim_alphas_prev,
|
||||
device=self.model.device,
|
||||
dtype=torch.float32)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas',
|
||||
np.sqrt(1. - ddim_alphas))
|
||||
torch.sqrt(1. - ddim_alphas))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
|
||||
(1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
self.register_buffer('ddim_sigmas_for_original_num_steps',
|
||||
sigmas_for_original_sampling_steps)
|
||||
self._schedule_cache = cache_key
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(
|
||||
@@ -228,10 +243,14 @@ class DDIMSampler(object):
|
||||
'x_inter_state': [state],
|
||||
'pred_x0_state': [state],
|
||||
}
|
||||
time_range = reversed(range(
|
||||
0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
|
||||
0]
|
||||
if ddim_use_original_steps:
|
||||
time_range = np.arange(timesteps - 1, -1, -1)
|
||||
else:
|
||||
time_range = np.flip(timesteps)
|
||||
time_range = np.ascontiguousarray(time_range)
|
||||
total_steps = int(time_range.shape[0])
|
||||
t_seq = torch.as_tensor(time_range, device=device, dtype=torch.long)
|
||||
ts_batch = t_seq.unsqueeze(1).expand(total_steps, b).contiguous()
|
||||
if verbose:
|
||||
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
||||
else:
|
||||
@@ -243,7 +262,7 @@ class DDIMSampler(object):
|
||||
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b, ), step, device=device, dtype=torch.long)
|
||||
ts = ts_batch[i]
|
||||
|
||||
# Use mask to blend noised original latent (img_orig) & new sampled latent (img)
|
||||
if mask is not None:
|
||||
@@ -378,16 +397,14 @@ class DDIMSampler(object):
|
||||
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
|
||||
if is_video:
|
||||
size = (b, 1, 1, 1, 1)
|
||||
size = (1, 1, 1, 1, 1)
|
||||
else:
|
||||
size = (b, 1, 1, 1)
|
||||
size = (1, 1, 1, 1)
|
||||
|
||||
a_t = torch.full(size, alphas[index], device=device)
|
||||
a_prev = torch.full(size, alphas_prev[index], device=device)
|
||||
sigma_t = torch.full(size, sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full(size,
|
||||
sqrt_one_minus_alphas[index],
|
||||
device=device)
|
||||
a_t = alphas[index].view(size)
|
||||
a_prev = alphas_prev[index].view(size)
|
||||
sigma_t = sigmas[index].view(size)
|
||||
sqrt_one_minus_at = sqrt_one_minus_alphas[index].view(size)
|
||||
|
||||
if self.model.parameterization != "v":
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
@@ -395,12 +412,8 @@ class DDIMSampler(object):
|
||||
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
|
||||
|
||||
if self.model.use_dynamic_rescale:
|
||||
scale_t = torch.full(size,
|
||||
self.ddim_scale_arr[index],
|
||||
device=device)
|
||||
prev_scale_t = torch.full(size,
|
||||
self.ddim_scale_arr_prev[index],
|
||||
device=device)
|
||||
scale_t = self.ddim_scale_arr[index].view(size)
|
||||
prev_scale_t = self.ddim_scale_arr_prev[index].view(size)
|
||||
rescale = (prev_scale_t / scale_t)
|
||||
pred_x0 *= rescale
|
||||
|
||||
|
||||
Reference in New Issue
Block a user