DDIM loop 内小张量分配优化,attention mask 缓存到 GPU,加速30s左右

This commit is contained in:
2026-01-18 22:37:55 +08:00
parent a90efc6718
commit cb334f308b
9 changed files with 103 additions and 49 deletions

View File

@@ -13,7 +13,7 @@ import time
import json
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass, field, asdict
from typing import Optional, Dict, List, Any
from typing import Optional, Dict, List, Any, Mapping
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
@@ -673,8 +673,8 @@ def get_latent_z(model, videos: Tensor) -> Tensor:
return z
def preprocess_observation(
model, observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
def preprocess_observation(
model, observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
"""Convert environment observation to LeRobot format observation.
Args:
observation: Dictionary of observation batches from a Gym vector environment.
@@ -715,7 +715,18 @@ def preprocess_observation(
return_observations['observation.state'].to(model.device)
})['observation.state']
return return_observations
return return_observations
def _move_to_device(batch: Mapping[str, Any],
device: torch.device) -> dict[str, Any]:
moved = {}
for key, value in batch.items():
if isinstance(value, torch.Tensor) and value.device != device:
moved[key] = value.to(device, non_blocking=True)
else:
moved[key] = value
return moved
def image_guided_synthesis_sim_mode(
@@ -768,8 +779,11 @@ def image_guided_synthesis_sim_mode(
profiler = get_profiler()
b, _, t, _, _ = noise_shape
ddim_sampler = DDIMSampler(model)
batch_size = noise_shape[0]
ddim_sampler = getattr(model, "_ddim_sampler", None)
if ddim_sampler is None:
ddim_sampler = DDIMSampler(model)
model._ddim_sampler = ddim_sampler
batch_size = noise_shape[0]
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
@@ -900,7 +914,7 @@ def image_guided_synthesis_sim_mode(
return batch_variants, actions, states
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
"""
Run inference pipeline on prompts and image inputs.
@@ -912,7 +926,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
Returns:
None
"""
profiler = get_profiler()
profiler = get_profiler()
# Create inference and tensorboard dirs
os.makedirs(args.savedir + '/inference', exist_ok=True)
@@ -1077,10 +1091,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
'action':
torch.zeros_like(batch['action'][-1]).unsqueeze(0)
}
observation = {
key: observation[key].to(device, non_blocking=True)
for key in observation
}
observation = _move_to_device(observation, device)
# Update observation queues
cond_obs_queues = populate_queues(cond_obs_queues, observation)
@@ -1093,7 +1104,9 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Multi-round interaction with the world-model
with pytorch_prof_ctx:
for itr in tqdm(range(args.n_iter)):
for itr in tqdm(range(args.n_iter)):
log_every = max(1, args.step_log_every)
log_step = (itr % log_every == 0)
profiler.current_iteration = itr
profiler.record_memory(f"iter_{itr}_start")
@@ -1111,13 +1124,11 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
'action':
torch.stack(list(cond_obs_queues['action']), dim=1),
}
observation = {
key: observation[key].to(device, non_blocking=True)
for key in observation
}
observation = _move_to_device(observation, device)
# Use world-model in policy to generate action
print(f'>>> Step {itr}: generating actions ...')
if log_step:
print(f'>>> Step {itr}: generating actions ...')
with profiler.profile_section("action_generation"):
pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
model,
@@ -1156,13 +1167,11 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
'action':
torch.stack(list(cond_obs_queues['action']), dim=1),
}
observation = {
key: observation[key].to(device, non_blocking=True)
for key in observation
}
observation = _move_to_device(observation, device)
# Interaction with the world-model
print(f'>>> Step {itr}: interacting with world model ...')
if log_step:
print(f'>>> Step {itr}: interacting with world model ...')
with profiler.profile_section("world_model_interaction"):
pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
model,
@@ -1364,6 +1373,12 @@ def get_parser():
default="fp32",
help="Dtype for VAE/first_stage_model weights and forward autocast."
)
parser.add_argument(
"--step_log_every",
type=int,
default=1,
help="Print per-iteration step logs every N iterations."
)
parser.add_argument(
"--n_action_steps",
type=int,