DDIM loop 内小张量分配优化,attention mask 缓存到 GPU,加速30s左右
This commit is contained in:
@@ -13,7 +13,7 @@ import time
|
|||||||
import json
|
import json
|
||||||
from contextlib import contextmanager, nullcontext
|
from contextlib import contextmanager, nullcontext
|
||||||
from dataclasses import dataclass, field, asdict
|
from dataclasses import dataclass, field, asdict
|
||||||
from typing import Optional, Dict, List, Any
|
from typing import Optional, Dict, List, Any, Mapping
|
||||||
|
|
||||||
from pytorch_lightning import seed_everything
|
from pytorch_lightning import seed_everything
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
@@ -718,6 +718,17 @@ def preprocess_observation(
|
|||||||
return return_observations
|
return return_observations
|
||||||
|
|
||||||
|
|
||||||
|
def _move_to_device(batch: Mapping[str, Any],
|
||||||
|
device: torch.device) -> dict[str, Any]:
|
||||||
|
moved = {}
|
||||||
|
for key, value in batch.items():
|
||||||
|
if isinstance(value, torch.Tensor) and value.device != device:
|
||||||
|
moved[key] = value.to(device, non_blocking=True)
|
||||||
|
else:
|
||||||
|
moved[key] = value
|
||||||
|
return moved
|
||||||
|
|
||||||
|
|
||||||
def image_guided_synthesis_sim_mode(
|
def image_guided_synthesis_sim_mode(
|
||||||
model: torch.nn.Module,
|
model: torch.nn.Module,
|
||||||
prompts: list[str],
|
prompts: list[str],
|
||||||
@@ -768,7 +779,10 @@ def image_guided_synthesis_sim_mode(
|
|||||||
profiler = get_profiler()
|
profiler = get_profiler()
|
||||||
|
|
||||||
b, _, t, _, _ = noise_shape
|
b, _, t, _, _ = noise_shape
|
||||||
|
ddim_sampler = getattr(model, "_ddim_sampler", None)
|
||||||
|
if ddim_sampler is None:
|
||||||
ddim_sampler = DDIMSampler(model)
|
ddim_sampler = DDIMSampler(model)
|
||||||
|
model._ddim_sampler = ddim_sampler
|
||||||
batch_size = noise_shape[0]
|
batch_size = noise_shape[0]
|
||||||
|
|
||||||
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
|
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
|
||||||
@@ -1077,10 +1091,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
'action':
|
'action':
|
||||||
torch.zeros_like(batch['action'][-1]).unsqueeze(0)
|
torch.zeros_like(batch['action'][-1]).unsqueeze(0)
|
||||||
}
|
}
|
||||||
observation = {
|
observation = _move_to_device(observation, device)
|
||||||
key: observation[key].to(device, non_blocking=True)
|
|
||||||
for key in observation
|
|
||||||
}
|
|
||||||
# Update observation queues
|
# Update observation queues
|
||||||
cond_obs_queues = populate_queues(cond_obs_queues, observation)
|
cond_obs_queues = populate_queues(cond_obs_queues, observation)
|
||||||
|
|
||||||
@@ -1094,6 +1105,8 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
# Multi-round interaction with the world-model
|
# Multi-round interaction with the world-model
|
||||||
with pytorch_prof_ctx:
|
with pytorch_prof_ctx:
|
||||||
for itr in tqdm(range(args.n_iter)):
|
for itr in tqdm(range(args.n_iter)):
|
||||||
|
log_every = max(1, args.step_log_every)
|
||||||
|
log_step = (itr % log_every == 0)
|
||||||
profiler.current_iteration = itr
|
profiler.current_iteration = itr
|
||||||
profiler.record_memory(f"iter_{itr}_start")
|
profiler.record_memory(f"iter_{itr}_start")
|
||||||
|
|
||||||
@@ -1111,12 +1124,10 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
'action':
|
'action':
|
||||||
torch.stack(list(cond_obs_queues['action']), dim=1),
|
torch.stack(list(cond_obs_queues['action']), dim=1),
|
||||||
}
|
}
|
||||||
observation = {
|
observation = _move_to_device(observation, device)
|
||||||
key: observation[key].to(device, non_blocking=True)
|
|
||||||
for key in observation
|
|
||||||
}
|
|
||||||
|
|
||||||
# Use world-model in policy to generate action
|
# Use world-model in policy to generate action
|
||||||
|
if log_step:
|
||||||
print(f'>>> Step {itr}: generating actions ...')
|
print(f'>>> Step {itr}: generating actions ...')
|
||||||
with profiler.profile_section("action_generation"):
|
with profiler.profile_section("action_generation"):
|
||||||
pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
|
pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
|
||||||
@@ -1156,12 +1167,10 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
'action':
|
'action':
|
||||||
torch.stack(list(cond_obs_queues['action']), dim=1),
|
torch.stack(list(cond_obs_queues['action']), dim=1),
|
||||||
}
|
}
|
||||||
observation = {
|
observation = _move_to_device(observation, device)
|
||||||
key: observation[key].to(device, non_blocking=True)
|
|
||||||
for key in observation
|
|
||||||
}
|
|
||||||
|
|
||||||
# Interaction with the world-model
|
# Interaction with the world-model
|
||||||
|
if log_step:
|
||||||
print(f'>>> Step {itr}: interacting with world model ...')
|
print(f'>>> Step {itr}: interacting with world model ...')
|
||||||
with profiler.profile_section("world_model_interaction"):
|
with profiler.profile_section("world_model_interaction"):
|
||||||
pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
|
pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
|
||||||
@@ -1364,6 +1373,12 @@ def get_parser():
|
|||||||
default="fp32",
|
default="fp32",
|
||||||
help="Dtype for VAE/first_stage_model weights and forward autocast."
|
help="Dtype for VAE/first_stage_model weights and forward autocast."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--step_log_every",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Print per-iteration step logs every N iterations."
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--n_action_steps",
|
"--n_action_steps",
|
||||||
type=int,
|
type=int,
|
||||||
|
|||||||
@@ -28,6 +28,11 @@ class DDIMSampler(object):
|
|||||||
ddim_discretize="uniform",
|
ddim_discretize="uniform",
|
||||||
ddim_eta=0.,
|
ddim_eta=0.,
|
||||||
verbose=True):
|
verbose=True):
|
||||||
|
device = self.model.betas.device
|
||||||
|
cache_key = (ddim_num_steps, ddim_discretize, float(ddim_eta),
|
||||||
|
str(device))
|
||||||
|
if getattr(self, "_schedule_cache", None) == cache_key:
|
||||||
|
return
|
||||||
self.ddim_timesteps = make_ddim_timesteps(
|
self.ddim_timesteps = make_ddim_timesteps(
|
||||||
ddim_discr_method=ddim_discretize,
|
ddim_discr_method=ddim_discretize,
|
||||||
num_ddim_timesteps=ddim_num_steps,
|
num_ddim_timesteps=ddim_num_steps,
|
||||||
@@ -67,16 +72,26 @@ class DDIMSampler(object):
|
|||||||
ddim_timesteps=self.ddim_timesteps,
|
ddim_timesteps=self.ddim_timesteps,
|
||||||
eta=ddim_eta,
|
eta=ddim_eta,
|
||||||
verbose=verbose)
|
verbose=verbose)
|
||||||
|
ddim_sigmas = torch.as_tensor(ddim_sigmas,
|
||||||
|
device=self.model.device,
|
||||||
|
dtype=torch.float32)
|
||||||
|
ddim_alphas = torch.as_tensor(ddim_alphas,
|
||||||
|
device=self.model.device,
|
||||||
|
dtype=torch.float32)
|
||||||
|
ddim_alphas_prev = torch.as_tensor(ddim_alphas_prev,
|
||||||
|
device=self.model.device,
|
||||||
|
dtype=torch.float32)
|
||||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||||
self.register_buffer('ddim_sqrt_one_minus_alphas',
|
self.register_buffer('ddim_sqrt_one_minus_alphas',
|
||||||
np.sqrt(1. - ddim_alphas))
|
torch.sqrt(1. - ddim_alphas))
|
||||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
|
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
|
||||||
(1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
(1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||||
self.register_buffer('ddim_sigmas_for_original_num_steps',
|
self.register_buffer('ddim_sigmas_for_original_num_steps',
|
||||||
sigmas_for_original_sampling_steps)
|
sigmas_for_original_sampling_steps)
|
||||||
|
self._schedule_cache = cache_key
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample(
|
def sample(
|
||||||
@@ -228,10 +243,14 @@ class DDIMSampler(object):
|
|||||||
'x_inter_state': [state],
|
'x_inter_state': [state],
|
||||||
'pred_x0_state': [state],
|
'pred_x0_state': [state],
|
||||||
}
|
}
|
||||||
time_range = reversed(range(
|
if ddim_use_original_steps:
|
||||||
0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
time_range = np.arange(timesteps - 1, -1, -1)
|
||||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
|
else:
|
||||||
0]
|
time_range = np.flip(timesteps)
|
||||||
|
time_range = np.ascontiguousarray(time_range)
|
||||||
|
total_steps = int(time_range.shape[0])
|
||||||
|
t_seq = torch.as_tensor(time_range, device=device, dtype=torch.long)
|
||||||
|
ts_batch = t_seq.unsqueeze(1).expand(total_steps, b).contiguous()
|
||||||
if verbose:
|
if verbose:
|
||||||
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
||||||
else:
|
else:
|
||||||
@@ -243,7 +262,7 @@ class DDIMSampler(object):
|
|||||||
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
|
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
|
||||||
for i, step in enumerate(iterator):
|
for i, step in enumerate(iterator):
|
||||||
index = total_steps - i - 1
|
index = total_steps - i - 1
|
||||||
ts = torch.full((b, ), step, device=device, dtype=torch.long)
|
ts = ts_batch[i]
|
||||||
|
|
||||||
# Use mask to blend noised original latent (img_orig) & new sampled latent (img)
|
# Use mask to blend noised original latent (img_orig) & new sampled latent (img)
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
@@ -378,16 +397,14 @@ class DDIMSampler(object):
|
|||||||
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||||
|
|
||||||
if is_video:
|
if is_video:
|
||||||
size = (b, 1, 1, 1, 1)
|
size = (1, 1, 1, 1, 1)
|
||||||
else:
|
else:
|
||||||
size = (b, 1, 1, 1)
|
size = (1, 1, 1, 1)
|
||||||
|
|
||||||
a_t = torch.full(size, alphas[index], device=device)
|
a_t = alphas[index].view(size)
|
||||||
a_prev = torch.full(size, alphas_prev[index], device=device)
|
a_prev = alphas_prev[index].view(size)
|
||||||
sigma_t = torch.full(size, sigmas[index], device=device)
|
sigma_t = sigmas[index].view(size)
|
||||||
sqrt_one_minus_at = torch.full(size,
|
sqrt_one_minus_at = sqrt_one_minus_alphas[index].view(size)
|
||||||
sqrt_one_minus_alphas[index],
|
|
||||||
device=device)
|
|
||||||
|
|
||||||
if self.model.parameterization != "v":
|
if self.model.parameterization != "v":
|
||||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
@@ -395,12 +412,8 @@ class DDIMSampler(object):
|
|||||||
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
|
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
|
||||||
|
|
||||||
if self.model.use_dynamic_rescale:
|
if self.model.use_dynamic_rescale:
|
||||||
scale_t = torch.full(size,
|
scale_t = self.ddim_scale_arr[index].view(size)
|
||||||
self.ddim_scale_arr[index],
|
prev_scale_t = self.ddim_scale_arr_prev[index].view(size)
|
||||||
device=device)
|
|
||||||
prev_scale_t = torch.full(size,
|
|
||||||
self.ddim_scale_arr_prev[index],
|
|
||||||
device=device)
|
|
||||||
rescale = (prev_scale_t / scale_t)
|
rescale = (prev_scale_t / scale_t)
|
||||||
pred_x0 *= rescale
|
pred_x0 *= rescale
|
||||||
|
|
||||||
|
|||||||
@@ -99,6 +99,7 @@ class CrossAttention(nn.Module):
|
|||||||
self.agent_state_context_len = agent_state_context_len
|
self.agent_state_context_len = agent_state_context_len
|
||||||
self.agent_action_context_len = agent_action_context_len
|
self.agent_action_context_len = agent_action_context_len
|
||||||
self.cross_attention_scale_learnable = cross_attention_scale_learnable
|
self.cross_attention_scale_learnable = cross_attention_scale_learnable
|
||||||
|
self._attn_mask_cache = {}
|
||||||
if self.image_cross_attention:
|
if self.image_cross_attention:
|
||||||
self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
|
self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
|
||||||
self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
|
self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
|
||||||
@@ -275,7 +276,8 @@ class CrossAttention(nn.Module):
|
|||||||
attn_mask_aa = self._get_attn_mask_aa(x.shape[0],
|
attn_mask_aa = self._get_attn_mask_aa(x.shape[0],
|
||||||
q.shape[1],
|
q.shape[1],
|
||||||
k_aa.shape[1],
|
k_aa.shape[1],
|
||||||
block_size=16).to(k_aa.device)
|
block_size=16,
|
||||||
|
device=k_aa.device)
|
||||||
else:
|
else:
|
||||||
if not spatial_self_attn:
|
if not spatial_self_attn:
|
||||||
assert 1 > 2, ">>> ERROR: you should never go into here ..."
|
assert 1 > 2, ">>> ERROR: you should never go into here ..."
|
||||||
@@ -386,14 +388,26 @@ class CrossAttention(nn.Module):
|
|||||||
|
|
||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
|
|
||||||
def _get_attn_mask_aa(self, b, l1, l2, block_size=16):
|
def _get_attn_mask_aa(self,
|
||||||
|
b,
|
||||||
|
l1,
|
||||||
|
l2,
|
||||||
|
block_size=16,
|
||||||
|
device=None):
|
||||||
|
if device is None:
|
||||||
|
device = self.to_q.weight.device
|
||||||
|
cache_key = (b, l1, l2, block_size, str(device))
|
||||||
|
if cache_key in self._attn_mask_cache:
|
||||||
|
return self._attn_mask_cache[cache_key]
|
||||||
num_token = l2 // block_size
|
num_token = l2 // block_size
|
||||||
start_positions = ((torch.arange(b) % block_size) + 1) * num_token
|
start_positions = ((torch.arange(b, device=device) % block_size) +
|
||||||
col_indices = torch.arange(l2)
|
1) * num_token
|
||||||
|
col_indices = torch.arange(l2, device=device)
|
||||||
mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
|
mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
|
||||||
mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
|
mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
|
||||||
attn_mask = torch.zeros_like(mask, dtype=torch.float)
|
attn_mask = torch.zeros_like(mask, dtype=torch.float32)
|
||||||
attn_mask[mask] = float('-inf')
|
attn_mask[mask] = float('-inf')
|
||||||
|
self._attn_mask_cache[cache_key] = attn_mask
|
||||||
return attn_mask
|
return attn_mask
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
12
useful.sh
12
useful.sh
@@ -107,3 +107,15 @@ embedder:
|
|||||||
1. 新增 --encoder_mode {fp32, autocast, bf16_full}
|
1. 新增 --encoder_mode {fp32, autocast, bf16_full}
|
||||||
2. bf16_full = 权重 BF16 + 前向 BF16
|
2. bf16_full = 权重 BF16 + 前向 BF16
|
||||||
3. autocast = 权重 FP32 + 仅主干 autocast(现在的实现)
|
3. autocast = 权重 FP32 + 仅主干 autocast(现在的实现)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
1. DDIM loop 内小张量分配优化(已完成)
|
||||||
|
|
||||||
|
- 每步 torch.full(...) 改成预先构造/广播,减少 loop 内分配
|
||||||
|
- 位置:src/unifolm_wma/models/samplers/ddim.py
|
||||||
|
|
||||||
|
2. attention mask 缓存到 GPU(已完成)
|
||||||
|
|
||||||
|
- _get_attn_mask_aa 现在直接在目标 device 构造并缓存,避免每步 CPU→GPU 拷贝
|
||||||
|
- 位置:src/unifolm_wma/modules/attention.py
|
||||||
Reference in New Issue
Block a user