DDIM loop 内小张量分配优化,attention mask 缓存到 GPU,加速30s左右

This commit is contained in:
2026-01-18 22:37:55 +08:00
parent a90efc6718
commit cb334f308b
9 changed files with 103 additions and 49 deletions

View File

@@ -13,7 +13,7 @@ import time
import json
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass, field, asdict
from typing import Optional, Dict, List, Any
from typing import Optional, Dict, List, Any, Mapping
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
@@ -673,8 +673,8 @@ def get_latent_z(model, videos: Tensor) -> Tensor:
return z
def preprocess_observation(
model, observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
def preprocess_observation(
model, observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
"""Convert environment observation to LeRobot format observation.
Args:
observation: Dictionary of observation batches from a Gym vector environment.
@@ -715,7 +715,18 @@ def preprocess_observation(
return_observations['observation.state'].to(model.device)
})['observation.state']
return return_observations
return return_observations
def _move_to_device(batch: Mapping[str, Any],
device: torch.device) -> dict[str, Any]:
moved = {}
for key, value in batch.items():
if isinstance(value, torch.Tensor) and value.device != device:
moved[key] = value.to(device, non_blocking=True)
else:
moved[key] = value
return moved
def image_guided_synthesis_sim_mode(
@@ -768,8 +779,11 @@ def image_guided_synthesis_sim_mode(
profiler = get_profiler()
b, _, t, _, _ = noise_shape
ddim_sampler = DDIMSampler(model)
batch_size = noise_shape[0]
ddim_sampler = getattr(model, "_ddim_sampler", None)
if ddim_sampler is None:
ddim_sampler = DDIMSampler(model)
model._ddim_sampler = ddim_sampler
batch_size = noise_shape[0]
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
@@ -900,7 +914,7 @@ def image_guided_synthesis_sim_mode(
return batch_variants, actions, states
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
"""
Run inference pipeline on prompts and image inputs.
@@ -912,7 +926,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
Returns:
None
"""
profiler = get_profiler()
profiler = get_profiler()
# Create inference and tensorboard dirs
os.makedirs(args.savedir + '/inference', exist_ok=True)
@@ -1077,10 +1091,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
'action':
torch.zeros_like(batch['action'][-1]).unsqueeze(0)
}
observation = {
key: observation[key].to(device, non_blocking=True)
for key in observation
}
observation = _move_to_device(observation, device)
# Update observation queues
cond_obs_queues = populate_queues(cond_obs_queues, observation)
@@ -1093,7 +1104,9 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Multi-round interaction with the world-model
with pytorch_prof_ctx:
for itr in tqdm(range(args.n_iter)):
for itr in tqdm(range(args.n_iter)):
log_every = max(1, args.step_log_every)
log_step = (itr % log_every == 0)
profiler.current_iteration = itr
profiler.record_memory(f"iter_{itr}_start")
@@ -1111,13 +1124,11 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
'action':
torch.stack(list(cond_obs_queues['action']), dim=1),
}
observation = {
key: observation[key].to(device, non_blocking=True)
for key in observation
}
observation = _move_to_device(observation, device)
# Use world-model in policy to generate action
print(f'>>> Step {itr}: generating actions ...')
if log_step:
print(f'>>> Step {itr}: generating actions ...')
with profiler.profile_section("action_generation"):
pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
model,
@@ -1156,13 +1167,11 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
'action':
torch.stack(list(cond_obs_queues['action']), dim=1),
}
observation = {
key: observation[key].to(device, non_blocking=True)
for key in observation
}
observation = _move_to_device(observation, device)
# Interaction with the world-model
print(f'>>> Step {itr}: interacting with world model ...')
if log_step:
print(f'>>> Step {itr}: interacting with world model ...')
with profiler.profile_section("world_model_interaction"):
pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
model,
@@ -1364,6 +1373,12 @@ def get_parser():
default="fp32",
help="Dtype for VAE/first_stage_model weights and forward autocast."
)
parser.add_argument(
"--step_log_every",
type=int,
default=1,
help="Print per-iteration step logs every N iterations."
)
parser.add_argument(
"--n_action_steps",
type=int,

View File

@@ -28,6 +28,11 @@ class DDIMSampler(object):
ddim_discretize="uniform",
ddim_eta=0.,
verbose=True):
device = self.model.betas.device
cache_key = (ddim_num_steps, ddim_discretize, float(ddim_eta),
str(device))
if getattr(self, "_schedule_cache", None) == cache_key:
return
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
@@ -67,16 +72,26 @@ class DDIMSampler(object):
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose)
ddim_sigmas = torch.as_tensor(ddim_sigmas,
device=self.model.device,
dtype=torch.float32)
ddim_alphas = torch.as_tensor(ddim_alphas,
device=self.model.device,
dtype=torch.float32)
ddim_alphas_prev = torch.as_tensor(ddim_alphas_prev,
device=self.model.device,
dtype=torch.float32)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas',
np.sqrt(1. - ddim_alphas))
torch.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
(1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps)
self._schedule_cache = cache_key
@torch.no_grad()
def sample(
@@ -228,10 +243,14 @@ class DDIMSampler(object):
'x_inter_state': [state],
'pred_x0_state': [state],
}
time_range = reversed(range(
0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
0]
if ddim_use_original_steps:
time_range = np.arange(timesteps - 1, -1, -1)
else:
time_range = np.flip(timesteps)
time_range = np.ascontiguousarray(time_range)
total_steps = int(time_range.shape[0])
t_seq = torch.as_tensor(time_range, device=device, dtype=torch.long)
ts_batch = t_seq.unsqueeze(1).expand(total_steps, b).contiguous()
if verbose:
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
else:
@@ -243,7 +262,7 @@ class DDIMSampler(object):
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b, ), step, device=device, dtype=torch.long)
ts = ts_batch[i]
# Use mask to blend noised original latent (img_orig) & new sampled latent (img)
if mask is not None:
@@ -378,16 +397,14 @@ class DDIMSampler(object):
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
if is_video:
size = (b, 1, 1, 1, 1)
size = (1, 1, 1, 1, 1)
else:
size = (b, 1, 1, 1)
size = (1, 1, 1, 1)
a_t = torch.full(size, alphas[index], device=device)
a_prev = torch.full(size, alphas_prev[index], device=device)
sigma_t = torch.full(size, sigmas[index], device=device)
sqrt_one_minus_at = torch.full(size,
sqrt_one_minus_alphas[index],
device=device)
a_t = alphas[index].view(size)
a_prev = alphas_prev[index].view(size)
sigma_t = sigmas[index].view(size)
sqrt_one_minus_at = sqrt_one_minus_alphas[index].view(size)
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
@@ -395,12 +412,8 @@ class DDIMSampler(object):
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
if self.model.use_dynamic_rescale:
scale_t = torch.full(size,
self.ddim_scale_arr[index],
device=device)
prev_scale_t = torch.full(size,
self.ddim_scale_arr_prev[index],
device=device)
scale_t = self.ddim_scale_arr[index].view(size)
prev_scale_t = self.ddim_scale_arr_prev[index].view(size)
rescale = (prev_scale_t / scale_t)
pred_x0 *= rescale

View File

@@ -99,6 +99,7 @@ class CrossAttention(nn.Module):
self.agent_state_context_len = agent_state_context_len
self.agent_action_context_len = agent_action_context_len
self.cross_attention_scale_learnable = cross_attention_scale_learnable
self._attn_mask_cache = {}
if self.image_cross_attention:
self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
@@ -275,7 +276,8 @@ class CrossAttention(nn.Module):
attn_mask_aa = self._get_attn_mask_aa(x.shape[0],
q.shape[1],
k_aa.shape[1],
block_size=16).to(k_aa.device)
block_size=16,
device=k_aa.device)
else:
if not spatial_self_attn:
assert 1 > 2, ">>> ERROR: you should never go into here ..."
@@ -386,14 +388,26 @@ class CrossAttention(nn.Module):
return self.to_out(out)
def _get_attn_mask_aa(self, b, l1, l2, block_size=16):
def _get_attn_mask_aa(self,
b,
l1,
l2,
block_size=16,
device=None):
if device is None:
device = self.to_q.weight.device
cache_key = (b, l1, l2, block_size, str(device))
if cache_key in self._attn_mask_cache:
return self._attn_mask_cache[cache_key]
num_token = l2 // block_size
start_positions = ((torch.arange(b) % block_size) + 1) * num_token
col_indices = torch.arange(l2)
start_positions = ((torch.arange(b, device=device) % block_size) +
1) * num_token
col_indices = torch.arange(l2, device=device)
mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
attn_mask = torch.zeros_like(mask, dtype=torch.float)
attn_mask = torch.zeros_like(mask, dtype=torch.float32)
attn_mask[mask] = float('-inf')
self._attn_mask_cache[cache_key] = attn_mask
return attn_mask

View File

@@ -106,4 +106,16 @@ embedder
1. 新增 --encoder_mode {fp32, autocast, bf16_full}
2. bf16_full = 权重 BF16 + 前向 BF16
3. autocast = 权重 FP32 + 仅主干 autocast(现在的实现)
3. autocast = 权重 FP32 + 仅主干 autocast(现在的实现)
1. DDIM loop 内小张量分配优化(已完成)
- 每步 torch.full(...) 改成预先构造/广播,减少 loop 内分配
- 位置:src/unifolm_wma/models/samplers/ddim.py
2. attention mask 缓存到 GPU(已完成)
- _get_attn_mask_aa 现在直接在目标 device 构造并缓存,避免每步 CPU→GPU 拷贝
- 位置src/unifolm_wma/modules/attention.py