DDIM loop 内小张量分配优化,attention mask 缓存到 GPU,加速30s左右

This commit is contained in:
2026-01-18 22:37:55 +08:00
parent a90efc6718
commit cb334f308b
9 changed files with 103 additions and 49 deletions

View File

@@ -13,7 +13,7 @@ import time
import json import json
from contextlib import contextmanager, nullcontext from contextlib import contextmanager, nullcontext
from dataclasses import dataclass, field, asdict from dataclasses import dataclass, field, asdict
from typing import Optional, Dict, List, Any from typing import Optional, Dict, List, Any, Mapping
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
from omegaconf import OmegaConf from omegaconf import OmegaConf
@@ -718,6 +718,17 @@ def preprocess_observation(
return return_observations return return_observations
def _move_to_device(batch: Mapping[str, Any],
device: torch.device) -> dict[str, Any]:
moved = {}
for key, value in batch.items():
if isinstance(value, torch.Tensor) and value.device != device:
moved[key] = value.to(device, non_blocking=True)
else:
moved[key] = value
return moved
def image_guided_synthesis_sim_mode( def image_guided_synthesis_sim_mode(
model: torch.nn.Module, model: torch.nn.Module,
prompts: list[str], prompts: list[str],
@@ -768,7 +779,10 @@ def image_guided_synthesis_sim_mode(
profiler = get_profiler() profiler = get_profiler()
b, _, t, _, _ = noise_shape b, _, t, _, _ = noise_shape
ddim_sampler = getattr(model, "_ddim_sampler", None)
if ddim_sampler is None:
ddim_sampler = DDIMSampler(model) ddim_sampler = DDIMSampler(model)
model._ddim_sampler = ddim_sampler
batch_size = noise_shape[0] batch_size = noise_shape[0]
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device) fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
@@ -1077,10 +1091,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
'action': 'action':
torch.zeros_like(batch['action'][-1]).unsqueeze(0) torch.zeros_like(batch['action'][-1]).unsqueeze(0)
} }
observation = { observation = _move_to_device(observation, device)
key: observation[key].to(device, non_blocking=True)
for key in observation
}
# Update observation queues # Update observation queues
cond_obs_queues = populate_queues(cond_obs_queues, observation) cond_obs_queues = populate_queues(cond_obs_queues, observation)
@@ -1094,6 +1105,8 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Multi-round interaction with the world-model # Multi-round interaction with the world-model
with pytorch_prof_ctx: with pytorch_prof_ctx:
for itr in tqdm(range(args.n_iter)): for itr in tqdm(range(args.n_iter)):
log_every = max(1, args.step_log_every)
log_step = (itr % log_every == 0)
profiler.current_iteration = itr profiler.current_iteration = itr
profiler.record_memory(f"iter_{itr}_start") profiler.record_memory(f"iter_{itr}_start")
@@ -1111,12 +1124,10 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
'action': 'action':
torch.stack(list(cond_obs_queues['action']), dim=1), torch.stack(list(cond_obs_queues['action']), dim=1),
} }
observation = { observation = _move_to_device(observation, device)
key: observation[key].to(device, non_blocking=True)
for key in observation
}
# Use world-model in policy to generate action # Use world-model in policy to generate action
if log_step:
print(f'>>> Step {itr}: generating actions ...') print(f'>>> Step {itr}: generating actions ...')
with profiler.profile_section("action_generation"): with profiler.profile_section("action_generation"):
pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode( pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
@@ -1156,12 +1167,10 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
'action': 'action':
torch.stack(list(cond_obs_queues['action']), dim=1), torch.stack(list(cond_obs_queues['action']), dim=1),
} }
observation = { observation = _move_to_device(observation, device)
key: observation[key].to(device, non_blocking=True)
for key in observation
}
# Interaction with the world-model # Interaction with the world-model
if log_step:
print(f'>>> Step {itr}: interacting with world model ...') print(f'>>> Step {itr}: interacting with world model ...')
with profiler.profile_section("world_model_interaction"): with profiler.profile_section("world_model_interaction"):
pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode( pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
@@ -1364,6 +1373,12 @@ def get_parser():
default="fp32", default="fp32",
help="Dtype for VAE/first_stage_model weights and forward autocast." help="Dtype for VAE/first_stage_model weights and forward autocast."
) )
parser.add_argument(
"--step_log_every",
type=int,
default=1,
help="Print per-iteration step logs every N iterations."
)
parser.add_argument( parser.add_argument(
"--n_action_steps", "--n_action_steps",
type=int, type=int,

View File

@@ -28,6 +28,11 @@ class DDIMSampler(object):
ddim_discretize="uniform", ddim_discretize="uniform",
ddim_eta=0., ddim_eta=0.,
verbose=True): verbose=True):
device = self.model.betas.device
cache_key = (ddim_num_steps, ddim_discretize, float(ddim_eta),
str(device))
if getattr(self, "_schedule_cache", None) == cache_key:
return
self.ddim_timesteps = make_ddim_timesteps( self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize, ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps, num_ddim_timesteps=ddim_num_steps,
@@ -67,16 +72,26 @@ class DDIMSampler(object):
ddim_timesteps=self.ddim_timesteps, ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta, eta=ddim_eta,
verbose=verbose) verbose=verbose)
ddim_sigmas = torch.as_tensor(ddim_sigmas,
device=self.model.device,
dtype=torch.float32)
ddim_alphas = torch.as_tensor(ddim_alphas,
device=self.model.device,
dtype=torch.float32)
ddim_alphas_prev = torch.as_tensor(ddim_alphas_prev,
device=self.model.device,
dtype=torch.float32)
self.register_buffer('ddim_sigmas', ddim_sigmas) self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas) self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', self.register_buffer('ddim_sqrt_one_minus_alphas',
np.sqrt(1. - ddim_alphas)) torch.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
(1 - self.alphas_cumprod / self.alphas_cumprod_prev)) (1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', self.register_buffer('ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps) sigmas_for_original_sampling_steps)
self._schedule_cache = cache_key
@torch.no_grad() @torch.no_grad()
def sample( def sample(
@@ -228,10 +243,14 @@ class DDIMSampler(object):
'x_inter_state': [state], 'x_inter_state': [state],
'pred_x0_state': [state], 'pred_x0_state': [state],
} }
time_range = reversed(range( if ddim_use_original_steps:
0, timesteps)) if ddim_use_original_steps else np.flip(timesteps) time_range = np.arange(timesteps - 1, -1, -1)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[ else:
0] time_range = np.flip(timesteps)
time_range = np.ascontiguousarray(time_range)
total_steps = int(time_range.shape[0])
t_seq = torch.as_tensor(time_range, device=device, dtype=torch.long)
ts_batch = t_seq.unsqueeze(1).expand(total_steps, b).contiguous()
if verbose: if verbose:
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
else: else:
@@ -243,7 +262,7 @@ class DDIMSampler(object):
dp_ddim_scheduler_state.set_timesteps(len(timesteps)) dp_ddim_scheduler_state.set_timesteps(len(timesteps))
for i, step in enumerate(iterator): for i, step in enumerate(iterator):
index = total_steps - i - 1 index = total_steps - i - 1
ts = torch.full((b, ), step, device=device, dtype=torch.long) ts = ts_batch[i]
# Use mask to blend noised original latent (img_orig) & new sampled latent (img) # Use mask to blend noised original latent (img_orig) & new sampled latent (img)
if mask is not None: if mask is not None:
@@ -378,16 +397,14 @@ class DDIMSampler(object):
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
if is_video: if is_video:
size = (b, 1, 1, 1, 1) size = (1, 1, 1, 1, 1)
else: else:
size = (b, 1, 1, 1) size = (1, 1, 1, 1)
a_t = torch.full(size, alphas[index], device=device) a_t = alphas[index].view(size)
a_prev = torch.full(size, alphas_prev[index], device=device) a_prev = alphas_prev[index].view(size)
sigma_t = torch.full(size, sigmas[index], device=device) sigma_t = sigmas[index].view(size)
sqrt_one_minus_at = torch.full(size, sqrt_one_minus_at = sqrt_one_minus_alphas[index].view(size)
sqrt_one_minus_alphas[index],
device=device)
if self.model.parameterization != "v": if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
@@ -395,12 +412,8 @@ class DDIMSampler(object):
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
if self.model.use_dynamic_rescale: if self.model.use_dynamic_rescale:
scale_t = torch.full(size, scale_t = self.ddim_scale_arr[index].view(size)
self.ddim_scale_arr[index], prev_scale_t = self.ddim_scale_arr_prev[index].view(size)
device=device)
prev_scale_t = torch.full(size,
self.ddim_scale_arr_prev[index],
device=device)
rescale = (prev_scale_t / scale_t) rescale = (prev_scale_t / scale_t)
pred_x0 *= rescale pred_x0 *= rescale

View File

@@ -99,6 +99,7 @@ class CrossAttention(nn.Module):
self.agent_state_context_len = agent_state_context_len self.agent_state_context_len = agent_state_context_len
self.agent_action_context_len = agent_action_context_len self.agent_action_context_len = agent_action_context_len
self.cross_attention_scale_learnable = cross_attention_scale_learnable self.cross_attention_scale_learnable = cross_attention_scale_learnable
self._attn_mask_cache = {}
if self.image_cross_attention: if self.image_cross_attention:
self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False) self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False) self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
@@ -275,7 +276,8 @@ class CrossAttention(nn.Module):
attn_mask_aa = self._get_attn_mask_aa(x.shape[0], attn_mask_aa = self._get_attn_mask_aa(x.shape[0],
q.shape[1], q.shape[1],
k_aa.shape[1], k_aa.shape[1],
block_size=16).to(k_aa.device) block_size=16,
device=k_aa.device)
else: else:
if not spatial_self_attn: if not spatial_self_attn:
assert 1 > 2, ">>> ERROR: you should never go into here ..." assert 1 > 2, ">>> ERROR: you should never go into here ..."
@@ -386,14 +388,26 @@ class CrossAttention(nn.Module):
return self.to_out(out) return self.to_out(out)
def _get_attn_mask_aa(self, b, l1, l2, block_size=16): def _get_attn_mask_aa(self,
b,
l1,
l2,
block_size=16,
device=None):
if device is None:
device = self.to_q.weight.device
cache_key = (b, l1, l2, block_size, str(device))
if cache_key in self._attn_mask_cache:
return self._attn_mask_cache[cache_key]
num_token = l2 // block_size num_token = l2 // block_size
start_positions = ((torch.arange(b) % block_size) + 1) * num_token start_positions = ((torch.arange(b, device=device) % block_size) +
col_indices = torch.arange(l2) 1) * num_token
col_indices = torch.arange(l2, device=device)
mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1) mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
mask = mask_2d.unsqueeze(1).expand(b, l1, l2) mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
attn_mask = torch.zeros_like(mask, dtype=torch.float) attn_mask = torch.zeros_like(mask, dtype=torch.float32)
attn_mask[mask] = float('-inf') attn_mask[mask] = float('-inf')
self._attn_mask_cache[cache_key] = attn_mask
return attn_mask return attn_mask

View File

@@ -107,3 +107,15 @@ embedder
1. 新增 --encoder_mode {fp32, autocast, bf16_full} 1. 新增 --encoder_mode {fp32, autocast, bf16_full}
2. bf16_full = 权重 BF16 + 前向 BF16 2. bf16_full = 权重 BF16 + 前向 BF16
3. autocast = 权重 FP32 + 仅主干 autocast(现在的实现) 3. autocast = 权重 FP32 + 仅主干 autocast(现在的实现)
1. DDIM loop 内小张量分配优化(已完成)
- 每步 torch.full(...) 改成预先构造/广播,减少 loop 内分配
- 位置:src/unifolm_wma/models/samplers/ddim.py
2. attention mask 缓存到 GPU(已完成)
- _get_attn_mask_aa 现在直接在目标 device 构造并缓存,避免每步 CPU→GPU 拷贝
- 位置:src/unifolm_wma/modules/attention.py