DDIM loop 内小张量分配优化,attention mask 缓存到 GPU

This commit is contained in:
qhy
2026-02-10 16:53:00 +08:00
parent ed637c972b
commit 91a9b0febc
5 changed files with 283 additions and 32 deletions

View File

@@ -67,11 +67,12 @@ class DDIMSampler(object):
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
# Ensure tensors are on correct device for efficient indexing
self.register_buffer('ddim_sigmas', to_torch(torch.as_tensor(ddim_sigmas)))
self.register_buffer('ddim_alphas', to_torch(torch.as_tensor(ddim_alphas)))
self.register_buffer('ddim_alphas_prev', to_torch(torch.as_tensor(ddim_alphas_prev)))
self.register_buffer('ddim_sqrt_one_minus_alphas',
np.sqrt(1. - ddim_alphas))
to_torch(torch.as_tensor(np.sqrt(1. - ddim_alphas))))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
(1 - self.alphas_cumprod / self.alphas_cumprod_prev))
@@ -241,9 +242,10 @@ class DDIMSampler(object):
dp_ddim_scheduler_action.set_timesteps(len(timesteps))
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
ts = torch.empty((b, ), device=device, dtype=torch.long)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b, ), step, device=device, dtype=torch.long)
ts.fill_(step)
# Use mask to blend noised original latent (img_orig) & new sampled latent (img)
if mask is not None:
@@ -325,10 +327,6 @@ class DDIMSampler(object):
guidance_rescale=0.0,
**kwargs):
b, *_, device = *x.shape, x.device
if x.dim() == 5:
is_video = True
else:
is_video = False
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output, model_output_action, model_output_state = self.model.apply_model(
@@ -377,17 +375,11 @@ class DDIMSampler(object):
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
if is_video:
size = (b, 1, 1, 1, 1)
else:
size = (b, 1, 1, 1)
a_t = torch.full(size, alphas[index], device=device)
a_prev = torch.full(size, alphas_prev[index], device=device)
sigma_t = torch.full(size, sigmas[index], device=device)
sqrt_one_minus_at = torch.full(size,
sqrt_one_minus_alphas[index],
device=device)
# Use 0-d tensors directly (already on device); broadcasting handles shape
a_t = alphas[index]
a_prev = alphas_prev[index]
sigma_t = sigmas[index]
sqrt_one_minus_at = sqrt_one_minus_alphas[index]
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
@@ -395,12 +387,8 @@ class DDIMSampler(object):
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
if self.model.use_dynamic_rescale:
scale_t = torch.full(size,
self.ddim_scale_arr[index],
device=device)
prev_scale_t = torch.full(size,
self.ddim_scale_arr_prev[index],
device=device)
scale_t = self.ddim_scale_arr[index]
prev_scale_t = self.ddim_scale_arr_prev[index]
rescale = (prev_scale_t / scale_t)
pred_x0 *= rescale