DDIM loop 内小张量分配优化,attention mask 缓存到 GPU,加速30s左右
This commit is contained in:
@@ -13,7 +13,7 @@ import time
|
||||
import json
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from typing import Optional, Dict, List, Any
|
||||
from typing import Optional, Dict, List, Any, Mapping
|
||||
|
||||
from pytorch_lightning import seed_everything
|
||||
from omegaconf import OmegaConf
|
||||
@@ -673,8 +673,8 @@ def get_latent_z(model, videos: Tensor) -> Tensor:
|
||||
return z
|
||||
|
||||
|
||||
def preprocess_observation(
|
||||
model, observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
|
||||
def preprocess_observation(
|
||||
model, observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
|
||||
"""Convert environment observation to LeRobot format observation.
|
||||
Args:
|
||||
observation: Dictionary of observation batches from a Gym vector environment.
|
||||
@@ -715,7 +715,18 @@ def preprocess_observation(
|
||||
return_observations['observation.state'].to(model.device)
|
||||
})['observation.state']
|
||||
|
||||
return return_observations
|
||||
return return_observations
|
||||
|
||||
|
||||
def _move_to_device(batch: Mapping[str, Any],
|
||||
device: torch.device) -> dict[str, Any]:
|
||||
moved = {}
|
||||
for key, value in batch.items():
|
||||
if isinstance(value, torch.Tensor) and value.device != device:
|
||||
moved[key] = value.to(device, non_blocking=True)
|
||||
else:
|
||||
moved[key] = value
|
||||
return moved
|
||||
|
||||
|
||||
def image_guided_synthesis_sim_mode(
|
||||
@@ -768,8 +779,11 @@ def image_guided_synthesis_sim_mode(
|
||||
profiler = get_profiler()
|
||||
|
||||
b, _, t, _, _ = noise_shape
|
||||
ddim_sampler = DDIMSampler(model)
|
||||
batch_size = noise_shape[0]
|
||||
ddim_sampler = getattr(model, "_ddim_sampler", None)
|
||||
if ddim_sampler is None:
|
||||
ddim_sampler = DDIMSampler(model)
|
||||
model._ddim_sampler = ddim_sampler
|
||||
batch_size = noise_shape[0]
|
||||
|
||||
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
|
||||
|
||||
@@ -900,7 +914,7 @@ def image_guided_synthesis_sim_mode(
|
||||
return batch_variants, actions, states
|
||||
|
||||
|
||||
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
"""
|
||||
Run inference pipeline on prompts and image inputs.
|
||||
|
||||
@@ -912,7 +926,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
profiler = get_profiler()
|
||||
profiler = get_profiler()
|
||||
|
||||
# Create inference and tensorboard dirs
|
||||
os.makedirs(args.savedir + '/inference', exist_ok=True)
|
||||
@@ -1077,10 +1091,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
'action':
|
||||
torch.zeros_like(batch['action'][-1]).unsqueeze(0)
|
||||
}
|
||||
observation = {
|
||||
key: observation[key].to(device, non_blocking=True)
|
||||
for key in observation
|
||||
}
|
||||
observation = _move_to_device(observation, device)
|
||||
# Update observation queues
|
||||
cond_obs_queues = populate_queues(cond_obs_queues, observation)
|
||||
|
||||
@@ -1093,7 +1104,9 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
|
||||
# Multi-round interaction with the world-model
|
||||
with pytorch_prof_ctx:
|
||||
for itr in tqdm(range(args.n_iter)):
|
||||
for itr in tqdm(range(args.n_iter)):
|
||||
log_every = max(1, args.step_log_every)
|
||||
log_step = (itr % log_every == 0)
|
||||
profiler.current_iteration = itr
|
||||
profiler.record_memory(f"iter_{itr}_start")
|
||||
|
||||
@@ -1111,13 +1124,11 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
'action':
|
||||
torch.stack(list(cond_obs_queues['action']), dim=1),
|
||||
}
|
||||
observation = {
|
||||
key: observation[key].to(device, non_blocking=True)
|
||||
for key in observation
|
||||
}
|
||||
observation = _move_to_device(observation, device)
|
||||
|
||||
# Use world-model in policy to generate action
|
||||
print(f'>>> Step {itr}: generating actions ...')
|
||||
if log_step:
|
||||
print(f'>>> Step {itr}: generating actions ...')
|
||||
with profiler.profile_section("action_generation"):
|
||||
pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
|
||||
model,
|
||||
@@ -1156,13 +1167,11 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
'action':
|
||||
torch.stack(list(cond_obs_queues['action']), dim=1),
|
||||
}
|
||||
observation = {
|
||||
key: observation[key].to(device, non_blocking=True)
|
||||
for key in observation
|
||||
}
|
||||
observation = _move_to_device(observation, device)
|
||||
|
||||
# Interaction with the world-model
|
||||
print(f'>>> Step {itr}: interacting with world model ...')
|
||||
if log_step:
|
||||
print(f'>>> Step {itr}: interacting with world model ...')
|
||||
with profiler.profile_section("world_model_interaction"):
|
||||
pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
|
||||
model,
|
||||
@@ -1364,6 +1373,12 @@ def get_parser():
|
||||
default="fp32",
|
||||
help="Dtype for VAE/first_stage_model weights and forward autocast."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--step_log_every",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Print per-iteration step logs every N iterations."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_action_steps",
|
||||
type=int,
|
||||
|
||||
Reference in New Issue
Block a user