Optimize small-tensor allocations inside the DDIM loop and cache attention masks on the GPU; speeds sampling up by roughly 30 s

2026-01-18 22:37:55 +08:00
parent a90efc6718
commit cb334f308b
9 changed files with 103 additions and 49 deletions

View File

@@ -28,6 +28,11 @@ class DDIMSampler(object):
ddim_discretize="uniform",
ddim_eta=0.,
verbose=True):
+device = self.model.betas.device
+cache_key = (ddim_num_steps, ddim_discretize, float(ddim_eta),
+str(device))
+if getattr(self, "_schedule_cache", None) == cache_key:
+return
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
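This hunk memoizes make_schedule: if the step count, discretization, eta, and device all match the previous call, the early return keeps the already-registered buffers instead of rebuilding the schedule. A minimal sketch of the same pattern, with a hypothetical ScheduleHolder class standing in for the sampler:

import torch

class ScheduleHolder:
    # Illustrative stand-in for DDIMSampler, not the repo's class.
    def __init__(self, device="cpu"):
        self.device = torch.device(device)

    def make_schedule(self, num_steps, discretize="uniform", eta=0.0):
        cache_key = (num_steps, discretize, float(eta), str(self.device))
        if getattr(self, "_schedule_cache", None) == cache_key:
            return  # identical request: reuse the tensors built last time
        # Expensive schedule construction runs only on a cache miss.
        self.timesteps = torch.linspace(0, 999, num_steps, device=self.device).long()
        self._schedule_cache = cache_key

holder = ScheduleHolder()
holder.make_schedule(50)
holder.make_schedule(50)  # second call hits the cache and returns immediately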
@@ -67,16 +72,26 @@ class DDIMSampler(object):
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose)
+ddim_sigmas = torch.as_tensor(ddim_sigmas,
+device=self.model.device,
+dtype=torch.float32)
+ddim_alphas = torch.as_tensor(ddim_alphas,
+device=self.model.device,
+dtype=torch.float32)
+ddim_alphas_prev = torch.as_tensor(ddim_alphas_prev,
+device=self.model.device,
+dtype=torch.float32)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas',
-np.sqrt(1. - ddim_alphas))
+torch.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
(1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps)
+self._schedule_cache = cache_key
@torch.no_grad()
def sample(
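The hunk above converts the numpy schedule arrays to float32 tensors on the model's device once, at schedule-build time, so every later alphas[index] lookup yields a tensor that is already on the GPU. That also forces the np.sqrt -> torch.sqrt change, since np.sqrt cannot operate on a CUDA tensor. A small sketch of the conversion; shapes and values are purely illustrative:

import numpy as np
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Schedule values are computed once in numpy (as make_ddim_sampling_parameters does).
alphas_np = np.linspace(0.99, 0.01, 50)

# Convert once; all subsequent indexing stays on the device.
ddim_alphas = torch.as_tensor(alphas_np, device=device, dtype=torch.float32)
sqrt_one_minus = torch.sqrt(1. - ddim_alphas)  # torch.sqrt, now that this is a tensor

a_t = ddim_alphas[10]  # 0-dim tensor already on the device; no per-step host copy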
@@ -228,10 +243,14 @@ class DDIMSampler(object):
'x_inter_state': [state],
'pred_x0_state': [state],
}
-time_range = reversed(range(
-0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
-total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
-0]
+if ddim_use_original_steps:
+time_range = np.arange(timesteps - 1, -1, -1)
+else:
+time_range = np.flip(timesteps)
+time_range = np.ascontiguousarray(time_range)
+total_steps = int(time_range.shape[0])
+t_seq = torch.as_tensor(time_range, device=device, dtype=torch.long)
+ts_batch = t_seq.unsqueeze(1).expand(total_steps, b).contiguous()
if verbose:
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
else:
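This rewrite replaces the per-step torch.full((b,), step, ...) call (one small allocation plus a host-to-device scalar copy on every iteration) with a single (total_steps, b) timestep matrix uploaded once; each iteration then reads a row, which is just a view. The np.ascontiguousarray call is needed because np.flip returns a negatively strided view that torch.as_tensor cannot wrap directly. A self-contained sketch with illustrative sizes:

import numpy as np
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
b = 4                                         # batch size (illustrative)
time_range = np.flip(np.arange(0, 1000, 20))  # DDIM steps, descending
time_range = np.ascontiguousarray(time_range)
total_steps = int(time_range.shape[0])

# One contiguous host-to-device copy, up front.
t_seq = torch.as_tensor(time_range, device=device, dtype=torch.long)
ts_batch = t_seq.unsqueeze(1).expand(total_steps, b).contiguous()

for i, step in enumerate(time_range):
    ts = ts_batch[i]  # row view: no allocation inside the sampling loop
    assert ts.shape == (b,) and int(ts[0]) == int(step)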
@@ -243,7 +262,7 @@ class DDIMSampler(object):
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
for i, step in enumerate(iterator):
index = total_steps - i - 1
-ts = torch.full((b, ), step, device=device, dtype=torch.long)
+ts = ts_batch[i]
# Use mask to blend noised original latent (img_orig) & new sampled latent (img)
if mask is not None:
@@ -378,16 +397,14 @@ class DDIMSampler(object):
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
if is_video:
-size = (b, 1, 1, 1, 1)
+size = (1, 1, 1, 1, 1)
else:
-size = (b, 1, 1, 1)
+size = (1, 1, 1, 1)
-a_t = torch.full(size, alphas[index], device=device)
-a_prev = torch.full(size, alphas_prev[index], device=device)
-sigma_t = torch.full(size, sigmas[index], device=device)
-sqrt_one_minus_at = torch.full(size,
-sqrt_one_minus_alphas[index],
-device=device)
+a_t = alphas[index].view(size)
+a_prev = alphas_prev[index].view(size)
+sigma_t = sigmas[index].view(size)
+sqrt_one_minus_at = sqrt_one_minus_alphas[index].view(size)
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
@@ -395,12 +412,8 @@ class DDIMSampler(object):
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
if self.model.use_dynamic_rescale:
-scale_t = torch.full(size,
-self.ddim_scale_arr[index],
-device=device)
-prev_scale_t = torch.full(size,
-self.ddim_scale_arr_prev[index],
-device=device)
+scale_t = self.ddim_scale_arr[index].view(size)
+prev_scale_t = self.ddim_scale_arr_prev[index].view(size)
rescale = (prev_scale_t / scale_t)
pred_x0 *= rescale
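Both hunks above swap torch.full(size, schedule[index], device=device) for schedule[index].view(size). torch.full allocates a fresh tensor on every step and takes a host-side fill value, so the schedule entry is first read back from the GPU; .view is a zero-copy reshape of the on-device entry, and broadcasting against the latent supplies the batch dimension, which is why size drops its leading b. A sketch of the equivalence, with illustrative shapes:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
b, index = 4, 7
alphas = torch.linspace(0.99, 0.01, 50, device=device)
x = torch.randn(b, 4, 8, 8, device=device)

# Old pattern: read the scalar back to the host, allocate a (b,1,1,1) tensor.
a_t_old = torch.full((b, 1, 1, 1), alphas[index].item(), device=device)

# New pattern: zero-copy view; broadcasting handles the batch dimension.
a_t = alphas[index].view(1, 1, 1, 1)

assert torch.allclose(x * a_t, x * a_t_old)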

View File

@@ -99,6 +99,7 @@ class CrossAttention(nn.Module):
self.agent_state_context_len = agent_state_context_len
self.agent_action_context_len = agent_action_context_len
self.cross_attention_scale_learnable = cross_attention_scale_learnable
+self._attn_mask_cache = {}
if self.image_cross_attention:
self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
@@ -275,7 +276,8 @@ class CrossAttention(nn.Module):
attn_mask_aa = self._get_attn_mask_aa(x.shape[0],
q.shape[1],
k_aa.shape[1],
-block_size=16).to(k_aa.device)
+block_size=16,
+device=k_aa.device)
else:
if not spatial_self_attn:
assert 1 > 2, ">>> ERROR: you should never go into here ..."
@@ -386,14 +388,26 @@ class CrossAttention(nn.Module):
return self.to_out(out)
-def _get_attn_mask_aa(self, b, l1, l2, block_size=16):
+def _get_attn_mask_aa(self,
+b,
+l1,
+l2,
+block_size=16,
+device=None):
+if device is None:
+device = self.to_q.weight.device
+cache_key = (b, l1, l2, block_size, str(device))
+if cache_key in self._attn_mask_cache:
+return self._attn_mask_cache[cache_key]
num_token = l2 // block_size
-start_positions = ((torch.arange(b) % block_size) + 1) * num_token
-col_indices = torch.arange(l2)
+start_positions = ((torch.arange(b, device=device) % block_size) +
+1) * num_token
+col_indices = torch.arange(l2, device=device)
mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
-attn_mask = torch.zeros_like(mask, dtype=torch.float)
+attn_mask = torch.zeros_like(mask, dtype=torch.float32)
attn_mask[mask] = float('-inf')
+self._attn_mask_cache[cache_key] = attn_mask
return attn_mask
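Taken together, _get_attn_mask_aa now builds the block-causal mask directly on the requested device (no CPU construction followed by .to(k_aa.device) at the call site) and memoizes the result per (b, l1, l2, block_size, device) key, so only the first forward pass with a given shape pays the construction cost. A standalone sketch of the cached helper, using a module-level dict and illustrative sizes; the cache is unbounded, which is acceptable here because only a handful of shape combinations occur in practice:

import torch

_attn_mask_cache = {}

def get_attn_mask(b, l1, l2, block_size=16, device="cpu"):
    # Additive mask: 0 where attention is allowed, -inf where it is blocked.
    cache_key = (b, l1, l2, block_size, str(device))
    if cache_key in _attn_mask_cache:
        return _attn_mask_cache[cache_key]
    num_token = l2 // block_size
    # Built on the target device directly: no CPU tensor + .to(device) round trip.
    start_positions = ((torch.arange(b, device=device) % block_size) + 1) * num_token
    col_indices = torch.arange(l2, device=device)
    mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)  # (b, l2)
    mask = mask_2d.unsqueeze(1).expand(b, l1, l2)                       # (b, l1, l2)
    attn_mask = torch.zeros_like(mask, dtype=torch.float32)
    attn_mask[mask] = float('-inf')
    _attn_mask_cache[cache_key] = attn_mask
    return attn_mask

m1 = get_attn_mask(4, 64, 256)
m2 = get_attn_mask(4, 64, 256)
assert m1 is m2  # second call is a dict lookup, not a rebuild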