添加三层迭代级性能分析工具 profile_iteration.py
Layer1: CUDA Events 精确测量每个itr内10个阶段耗时 Layer2: torch.profiler GPU timeline trace Layer3: CSV输出支持A/B对比 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
975
scripts/evaluation/profile_iteration.py
Normal file
975
scripts/evaluation/profile_iteration.py
Normal file
@@ -0,0 +1,975 @@
|
|||||||
|
"""
|
||||||
|
Profile the full iteration loop of world model interaction.
|
||||||
|
|
||||||
|
Three layers of profiling:
|
||||||
|
Layer 1: Iteration-level wall-clock breakdown (CUDA events)
|
||||||
|
Layer 2: GPU timeline trace (torch.profiler → Chrome trace)
|
||||||
|
Layer 3: A/B comparison (standardized CSV output)
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Layer 1 only (fast, default):
|
||||||
|
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 \
|
||||||
|
python scripts/evaluation/profile_iteration.py \
|
||||||
|
--ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
|
||||||
|
--config configs/inference/world_model_interaction.yaml \
|
||||||
|
--prompt_dir unitree_z1_dual_arm_cleanup_pencils/case1/world_model_interaction_prompts \
|
||||||
|
--dataset unitree_z1_dual_arm_cleanup_pencils \
|
||||||
|
--frame_stride 4 --n_iter 5
|
||||||
|
|
||||||
|
# Layer 1 + Layer 2 (GPU trace):
|
||||||
|
... --trace --trace_dir ./profile_traces
|
||||||
|
|
||||||
|
# Layer 3 (A/B comparison): run twice, diff the CSVs
|
||||||
|
... --csv baseline.csv
|
||||||
|
... --csv optimized.csv
|
||||||
|
python scripts/evaluation/profile_iteration.py --compare baseline.csv optimized.csv
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
from contextlib import nullcontext
|
||||||
|
|
||||||
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from PIL import Image
|
||||||
|
from pytorch_lightning import seed_everything
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||||
|
from unifolm_wma.utils.utils import instantiate_from_config
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
# Constants
# ──────────────────────────────────────────────────────────────────────
# Reference list of the top-level stages timed once per iteration
# (see run_profiled_iterations); "itr_total" is the whole-iteration
# time measured with its own pair of CUDA events.
STAGE_NAMES = [
    "stack_to_device_1",
    "synth_policy",
    "update_action_queue",
    "stack_to_device_2",
    "synth_world_model",
    "update_obs_queue",
    "tensorboard_log",
    "save_results",
    "cpu_transfer",
    "itr_total",
]

# Sub-stages inside image_guided_synthesis_sim_mode; profiled_synthesis
# records these under "<prefix>/<name>" keys (prefix is "policy" or "wm").
SYNTH_SUB_STAGES = [
    "ddim_sampler_init",
    "image_embedding",
    "vae_encode",
    "text_conditioning",
    "projectors",
    "cond_assembly",
    "ddim_sampling",
    "vae_decode",
]
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
# CudaTimer — GPU-precise timing via CUDA events
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
class CudaTimer:
    """Time the GPU work inside a ``with`` body using a pair of CUDA events.

    On exit, the elapsed milliseconds are appended to ``records[name]``,
    so ``records`` is expected to map stage names to lists (e.g. a
    ``defaultdict(list)``).
    """

    def __init__(self, name, records):
        self.name = name
        self.records = records

    def __enter__(self):
        # Drain any previously queued GPU work so it is not attributed
        # to this stage.
        torch.cuda.synchronize()
        self._start = torch.cuda.Event(enable_timing=True)
        self._end = torch.cuda.Event(enable_timing=True)
        self._start.record()
        return self

    def __exit__(self, *exc_info):
        self._end.record()
        # Block until the end event has actually executed; elapsed_time()
        # is only valid once both events completed.
        torch.cuda.synchronize()
        self.records[self.name].append(self._start.elapsed_time(self._end))
|
||||||
|
|
||||||
|
|
||||||
|
class WallTimer:
    """Time the host (CPU) wall-clock duration of a ``with`` body.

    CUDA is synchronized on both entry and exit so pending GPU work is
    neither hidden from nor misattributed to the measured span.  The
    elapsed milliseconds are appended to ``records[name]``.
    """

    def __init__(self, name, records):
        self.name = name
        self.records = records

    def __enter__(self):
        torch.cuda.synchronize()
        self._begin = time.perf_counter()
        return self

    def __exit__(self, *exc_info):
        torch.cuda.synchronize()
        self.records[self.name].append(
            (time.perf_counter() - self._begin) * 1000.0)
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
# Model loading (reused from world_model_interaction.py)
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
def patch_norm_bypass_autocast():
    """Monkey-patch GroupNorm/LayerNorm so normalization runs outside autocast.

    The replacement forwards disable the 'cuda' autocast region and cast the
    affine weight/bias to the input's dtype, so the norm math happens in the
    input precision instead of an autocast-promoted one.  The patch is
    process-global (it rebinds ``torch.nn.GroupNorm.forward`` and
    ``torch.nn.LayerNorm.forward``).
    """

    def _cast(param, dtype):
        # Affine parameters are optional on both norm layers.
        return param.to(dtype) if param is not None else None

    def _group_norm_forward(self, x):
        with torch.amp.autocast('cuda', enabled=False):
            return F.group_norm(x, self.num_groups,
                                _cast(self.weight, x.dtype),
                                _cast(self.bias, x.dtype),
                                self.eps)

    def _layer_norm_forward(self, x):
        with torch.amp.autocast('cuda', enabled=False):
            return F.layer_norm(x, self.normalized_shape,
                                _cast(self.weight, x.dtype),
                                _cast(self.bias, x.dtype),
                                self.eps)

    torch.nn.GroupNorm.forward = _group_norm_forward
    torch.nn.LayerNorm.forward = _layer_norm_forward
|
||||||
|
|
||||||
|
|
||||||
|
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
    """Compile the ResBlocks of selected (hot) UNet output blocks.

    Args:
        model: wrapper whose ``model.model.diffusion_model`` is the UNet.
        hot_indices: indices into ``unet.output_blocks`` whose ResBlocks
            get their ``_forward`` wrapped by ``torch.compile``.
    """
    from unifolm_wma.modules.networks.wma_model import ResBlock

    unet = model.model.diffusion_model
    n_compiled = 0
    for block_idx in hot_indices:
        for module in unet.output_blocks[block_idx]:
            if isinstance(module, ResBlock):
                module._forward = torch.compile(module._forward, mode="default")
                n_compiled += 1
    print(f" torch.compile: {n_compiled} ResBlocks in output_blocks{list(hot_indices)}")
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(args):
    """Instantiate the world model from config, load the checkpoint, and
    apply the inference precision policy (bf16 everywhere except optionally
    the VAE).

    Args:
        args: parsed CLI namespace; reads ``config``, ``ckpt_path``,
            ``perframe_ae`` and ``vae_dtype``.

    Returns:
        (model, config): the CUDA-resident, eval-mode model and the loaded
        OmegaConf config.
    """
    config = OmegaConf.load(args.config)
    # Gradient checkpointing is a training-time memory trade-off; disable
    # it for inference profiling so timings reflect the real forward cost.
    config['model']['params']['wma_config']['params']['use_checkpoint'] = False
    model = instantiate_from_config(config.model)
    model.perframe_ae = args.perframe_ae

    from collections import OrderedDict
    state_dict = torch.load(args.ckpt_path, map_location="cpu")
    # Lightning checkpoints nest the weights under "state_dict".
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    try:
        model.load_state_dict(state_dict, strict=True)
    except Exception:
        # Fallback for older checkpoints: the fps-embedding module was
        # renamed from "framestride_embed" to "fps_embedding".
        new_sd = OrderedDict()
        for k, v in state_dict.items():
            new_sd[k] = v
        for k in list(new_sd.keys()):
            if "framestride_embed" in k:
                new_sd[k.replace("framestride_embed", "fps_embedding")] = new_sd.pop(k)
        model.load_state_dict(new_sd, strict=True)

    model.eval()

    # Apply precision: bf16 diffusion + encoders + projectors, fp32/bf16 VAE.
    # The *_autocast_dtype attributes are read elsewhere (see
    # profiled_synthesis for diffusion_autocast_dtype); None means
    # "no autocast" for that component.
    model.model.to(torch.bfloat16)
    model.diffusion_autocast_dtype = torch.bfloat16
    model.embedder.to(torch.bfloat16)
    model.image_proj_model.to(torch.bfloat16)
    model.encoder_autocast_dtype = None
    model.state_projector.to(torch.bfloat16)
    model.action_projector.to(torch.bfloat16)
    model.projector_autocast_dtype = None
    if args.vae_dtype == "bf16":
        model.first_stage_model.to(torch.bfloat16)

    # Compile hot ResBlocks before moving to GPU.
    apply_torch_compile(model)
    model = model.cuda()
    print(">>> Model loaded and ready.")
    return model, config
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
# Data preparation (reused from world_model_interaction.py)
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
def get_init_frame_path(data_dir, sample):
    """Return the path of the sample's initial frame: <data_dir>/images/<sample dir>/<videoid>.png."""
    frame_name = str(sample['videoid']) + '.png'
    return os.path.join(data_dir, 'images', sample['data_dir'], frame_name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_transition_path(data_dir, sample):
    """Return the path of the sample's transition file: <data_dir>/transitions/<sample dir>/<videoid>.h5."""
    h5_name = str(sample['videoid']) + '.h5'
    return os.path.join(data_dir, 'transitions', sample['data_dir'], h5_name)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_init_input(start_idx, init_frame_path, transition_dict,
                       frame_stride, wma_data, video_length=16, n_obs_steps=2):
    """Build the initial model input from a prompt frame and logged transitions.

    Args:
        start_idx: index of the starting transition step.
        init_frame_path: path to the initial RGB frame on disk.
        transition_dict: tensors/attrs loaded from the transition h5 file;
            reads 'observation.state', 'action', 'action_type', 'state_type'.
        frame_stride: step between sampled action frames.
        wma_data: dataset object providing normalizer / get_uni_vec /
            spatial_transform (project type — behavior assumed from usage).
        video_length: number of action frames sampled at frame_stride.
        n_obs_steps: number of past states the model conditions on.

    Returns:
        (data, ori_state_dim, ori_action_dim): data maps 'observation.image'
        (normalized to [-1, 1]) plus the normalized action/state tensors;
        the dims are the original (pre-unification) feature widths.
    """
    indices = [start_idx + frame_stride * i for i in range(video_length)]
    init_frame = Image.open(init_frame_path).convert('RGB')
    # HWC uint8 -> (C, 1, H, W) float; the singleton dim is the time axis.
    init_frame = torch.tensor(np.array(init_frame)).unsqueeze(0).permute(3, 0, 1, 2).float()

    if start_idx < n_obs_steps - 1:
        # Not enough history yet: left-pad by repeating the earliest state.
        state_indices = list(range(0, start_idx + 1))
        states = transition_dict['observation.state'][state_indices, :]
        num_padding = n_obs_steps - 1 - start_idx
        padding = states[0:1, :].repeat(num_padding, 1)
        states = torch.cat((padding, states), dim=0)
    else:
        state_indices = list(range(start_idx - n_obs_steps + 1, start_idx + 1))
        states = transition_dict['observation.state'][state_indices, :]

    actions = transition_dict['action'][indices, :]
    ori_state_dim = states.shape[-1]
    ori_action_dim = actions.shape[-1]

    frames_action_state_dict = {
        'action': actions,
        'observation.state': states,
    }
    # Normalize, then map into the unified (cross-embodiment) vector layout.
    frames_action_state_dict = wma_data.normalizer(frames_action_state_dict)
    frames_action_state_dict = wma_data.get_uni_vec(
        frames_action_state_dict,
        transition_dict['action_type'],
        transition_dict['state_type'],
    )

    if wma_data.spatial_transform is not None:
        init_frame = wma_data.spatial_transform(init_frame)
    # Rescale pixels from [0, 255] to [-1, 1].
    init_frame = (init_frame / 255 - 0.5) * 2

    data = {'observation.image': init_frame}
    data.update(frames_action_state_dict)
    return data, ori_state_dim, ori_action_dim
|
||||||
|
|
||||||
|
|
||||||
|
def populate_queues(queues, batch):
    """Push batch entries into their matching bounded queues.

    A queue that is not yet full is filled to ``maxlen`` by repeating the
    new value (initial-history padding); a full queue gets a single append
    (the deque drops its oldest entry).  Batch keys without a queue are
    ignored.  Returns ``queues`` for chaining.
    """
    for key, value in batch.items():
        queue = queues.get(key)
        if queue is None:
            continue
        if len(queue) < queue.maxlen:
            while len(queue) < queue.maxlen:
                queue.append(value)
        else:
            queue.append(value)
    return queues
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
# Instrumented image_guided_synthesis_sim_mode with sub-stage timing
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
def get_latent_z(model, videos):
    """Encode a (b, c, t, h, w) video batch into VAE latents, same layout."""
    b, c, t, h, w = videos.shape
    # Fold time into batch so the image VAE sees (b*t) independent frames.
    frames = rearrange(videos, 'b c t h w -> (b t) c h w')
    vae_dtype = next(model.first_stage_model.parameters()).dtype
    frames = frames.to(dtype=vae_dtype)
    latents = model.encode_first_stage(frames)
    # Split the batch/time axes back apart on the latent grid.
    return rearrange(latents, '(b t) c h w -> b c t h w', b=b, t=t)
|
||||||
|
|
||||||
|
|
||||||
|
def save_results(video, filename, fps=8):
    """Write a (b, c, t, h, w) video tensor in [-1, 1] to an h264 mp4.

    Each timestep is rendered as one grid frame with the batch laid out
    in a single row of width ``b``.
    """
    clip = torch.clamp(video.detach().cpu().float(), -1., 1.)
    batch = clip.shape[0]
    # (b, c, t, h, w) -> (t, b, c, h, w): one framesheet per timestep.
    per_step = clip.permute(2, 0, 1, 3, 4)
    grids = [
        torchvision.utils.make_grid(sheet, nrow=int(batch), padding=0)
        for sheet in per_step
    ]
    stacked = torch.stack(grids, dim=0)
    # Map [-1, 1] -> [0, 255] uint8, channels-last as write_video expects.
    stacked = ((stacked + 1.0) / 2.0 * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(filename, stacked, fps=fps,
                               video_codec='h264', options={'crf': '10'})
|
||||||
|
|
||||||
|
|
||||||
|
def profiled_synthesis(model, prompts, observation, noise_shape,
                       ddim_steps, ddim_eta, unconditional_guidance_scale,
                       fs, text_input, timestep_spacing, guidance_rescale,
                       sim_mode, decode_video, records, prefix):
    """image_guided_synthesis_sim_mode with per-sub-stage CUDA event timing.

    Each sub-stage is wrapped in a CudaTimer recording under
    "<prefix>/<sub_stage>" (see SYNTH_SUB_STAGES).

    Args:
        prefix: "policy" or "wm" — prepended to sub-stage names in records.
        sim_mode: when False, the action conditioning embedding is zeroed.
        decode_video: when False, VAE decode is skipped and 0.0 is recorded
            for the vae_decode sub-stage so every itr has the same keys.

    Returns:
        (batch_variants, actions, states); batch_variants is None when
        decode_video is False.
    """
    b, _, t, _, _ = noise_shape
    batch_size = noise_shape[0]
    device = next(model.parameters()).device

    # --- sub-stage: ddim_sampler_init ---
    with CudaTimer(f"{prefix}/ddim_sampler_init", records):
        ddim_sampler = DDIMSampler(model)
        fs_t = torch.tensor([fs] * batch_size, dtype=torch.long, device=device)

    # --- sub-stage: image_embedding ---
    with CudaTimer(f"{prefix}/image_embedding", records):
        model_dtype = next(model.embedder.parameters()).dtype
        # (b, o, c, h, w) observation stack -> keep only the latest frame
        # for the image embedder.
        img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
        cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=model_dtype)
        cond_img_emb = model.embedder(cond_img)
        cond_img_emb = model.image_proj_model(cond_img_emb)

    # --- sub-stage: vae_encode ---
    with CudaTimer(f"{prefix}/vae_encode", records):
        # NOTE(review): `cond` is only created on the 'hybrid' branch; a
        # non-hybrid conditioning_key would make the cond_assembly stage
        # below raise NameError — presumably 'hybrid' is the only config
        # used here. Confirm against the model config.
        if model.model.conditioning_key == 'hybrid':
            z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
            # Repeat the last-frame latent across all t target frames.
            img_cat_cond = z[:, :, -1:, :, :]
            img_cat_cond = repeat(img_cat_cond,
                                  'b c t h w -> b c (repeat t) h w',
                                  repeat=noise_shape[2])
            cond = {"c_concat": [img_cat_cond]}

    # --- sub-stage: text_conditioning ---
    with CudaTimer(f"{prefix}/text_conditioning", records):
        if not text_input:
            prompts_use = [""] * batch_size
        else:
            prompts_use = prompts if isinstance(prompts, list) else [prompts] * batch_size
        cond_ins_emb = model.get_learned_conditioning(prompts_use)

    # --- sub-stage: projectors ---
    with CudaTimer(f"{prefix}/projectors", records):
        projector_dtype = next(model.state_projector.parameters()).dtype
        cond_state_emb = model.state_projector(
            observation['observation.state'].to(dtype=projector_dtype))
        cond_state_emb = cond_state_emb + model.agent_state_pos_emb
        cond_action_emb = model.action_projector(
            observation['action'].to(dtype=projector_dtype))
        cond_action_emb = cond_action_emb + model.agent_action_pos_emb
        # Policy mode (sim_mode=False) must not see the action conditioning.
        if not sim_mode:
            cond_action_emb = torch.zeros_like(cond_action_emb)

    # --- sub-stage: cond_assembly ---
    with CudaTimer(f"{prefix}/cond_assembly", records):
        cond["c_crossattn"] = [
            torch.cat([cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb], dim=1)
        ]
        cond["c_crossattn_action"] = [
            observation['observation.images.top'][:, :, -model.n_obs_steps_acting:],
            observation['observation.state'][:, -model.n_obs_steps_acting:],
            sim_mode,
            False,
        ]

    # --- sub-stage: ddim_sampling ---
    # Autocast the sampling loop when the precision policy asks for it
    # (diffusion_autocast_dtype is set in load_model).
    autocast_dtype = getattr(model, 'diffusion_autocast_dtype', None)
    if autocast_dtype is not None and device.type == 'cuda':
        autocast_ctx = torch.autocast('cuda', dtype=autocast_dtype)
    else:
        autocast_ctx = nullcontext()

    with CudaTimer(f"{prefix}/ddim_sampling", records):
        with autocast_ctx:
            samples, actions, states, _ = ddim_sampler.sample(
                S=ddim_steps,
                conditioning=cond,
                batch_size=batch_size,
                shape=noise_shape[1:],
                verbose=False,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=None,
                eta=ddim_eta,
                cfg_img=None,
                mask=None,
                x0=None,
                fs=fs_t,
                timestep_spacing=timestep_spacing,
                guidance_rescale=guidance_rescale,
                unconditional_conditioning_img_nonetext=None,
            )

    # --- sub-stage: vae_decode ---
    batch_variants = None
    if decode_video:
        with CudaTimer(f"{prefix}/vae_decode", records):
            batch_variants = model.decode_first_stage(samples)
    else:
        # Keep the key present with a 0.0 entry so CSV columns stay aligned.
        records[f"{prefix}/vae_decode"].append(0.0)

    return batch_variants, actions, states
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
# Instrumented iteration loop
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
def run_profiled_iterations(model, args, config, noise_shape, device):
    """Run the full iteration loop with per-stage timing.

    Mirrors the world_model_interaction loop: policy synthesis -> action
    queue update -> world-model synthesis -> observation queue update ->
    logging/saving, with every stage wrapped in a CudaTimer/WallTimer.

    Returns:
        all_records: list of dicts, one per itr, {stage_name: ms}
    """
    # Load the prompt metadata (first row only — one episode is profiled).
    csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
    df = pd.read_csv(csv_path)
    sample = df.iloc[0]

    data_module = instantiate_from_config(config.data)
    data_module.setup()

    init_frame_path = get_init_frame_path(args.prompt_dir, sample)
    ori_fps = float(sample['fps'])
    fs = args.frame_stride
    # Effective fps seen by the model after frame-striding.
    model_input_fs = ori_fps // fs

    # Load the whole transition file (datasets and attributes) into memory.
    transition_path = get_transition_path(args.prompt_dir, sample)
    with h5py.File(transition_path, 'r') as h5f:
        transition_dict = {}
        for key in h5f.keys():
            transition_dict[key] = torch.tensor(h5f[key][()])
        for key in h5f.attrs.keys():
            transition_dict[key] = h5f.attrs[key]

    # Prepare initial observation
    batch, ori_state_dim, ori_action_dim = prepare_init_input(
        0, init_frame_path, transition_dict, fs,
        data_module.test_datasets[args.dataset],
        n_obs_steps=model.n_obs_steps_imagen)

    observation = {
        'observation.images.top':
        batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0),
        'observation.state':
        batch['observation.state'][-1].unsqueeze(0),
        'action':
        torch.zeros_like(batch['action'][-1]).unsqueeze(0),
    }
    observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}

    # Bounded history queues; populate_queues left-pads them to maxlen.
    cond_obs_queues = {
        "observation.images.top": deque(maxlen=model.n_obs_steps_imagen),
        "observation.state": deque(maxlen=model.n_obs_steps_imagen),
        "action": deque(maxlen=args.video_length),
    }
    cond_obs_queues = populate_queues(cond_obs_queues, observation)

    # Temp dir for save_results profiling
    tmp_dir = os.path.join(args.savedir, "profile_tmp")
    os.makedirs(tmp_dir, exist_ok=True)

    prompt_text = sample['instruction']
    all_records = []

    print(f">>> Running {args.n_iter} profiled iterations ...")
    for itr in range(args.n_iter):
        rec = defaultdict(list)

        # ── itr_total start ──
        torch.cuda.synchronize()
        itr_start = torch.cuda.Event(enable_timing=True)
        itr_end = torch.cuda.Event(enable_timing=True)
        itr_start.record()

        # ① stack_to_device_1 — assemble the queued history into batched
        # tensors and move them to the GPU.
        with CudaTimer("stack_to_device_1", rec):
            observation = {
                'observation.images.top':
                torch.stack(list(cond_obs_queues['observation.images.top']),
                            dim=1).permute(0, 2, 1, 3, 4),
                'observation.state':
                torch.stack(list(cond_obs_queues['observation.state']), dim=1),
                'action':
                torch.stack(list(cond_obs_queues['action']), dim=1),
            }
            observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}

        # ② synth_policy — text-conditioned, sim_mode off (actions zeroed
        # in conditioning); video decode optional for speed.
        with CudaTimer("synth_policy", rec):
            pred_videos_0, pred_actions, _ = profiled_synthesis(
                model, prompt_text, observation, noise_shape,
                ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
                unconditional_guidance_scale=args.unconditional_guidance_scale,
                fs=model_input_fs, text_input=True,
                timestep_spacing=args.timestep_spacing,
                guidance_rescale=args.guidance_rescale,
                sim_mode=False,
                decode_video=not args.fast_policy_no_decode,
                records=rec, prefix="policy")

        # ③ update_action_queue — feed predicted actions back into the
        # queue, zeroing the padded tail beyond the original action dim.
        with WallTimer("update_action_queue", rec):
            for idx in range(len(pred_actions[0])):
                obs_a = {'action': pred_actions[0][idx:idx + 1]}
                obs_a['action'][:, ori_action_dim:] = 0.0
                cond_obs_queues = populate_queues(cond_obs_queues, obs_a)

        # ④ stack_to_device_2 — restack after the action-queue update.
        with CudaTimer("stack_to_device_2", rec):
            observation = {
                'observation.images.top':
                torch.stack(list(cond_obs_queues['observation.images.top']),
                            dim=1).permute(0, 2, 1, 3, 4),
                'observation.state':
                torch.stack(list(cond_obs_queues['observation.state']), dim=1),
                'action':
                torch.stack(list(cond_obs_queues['action']), dim=1),
            }
            observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}

        # ⑤ synth_world_model — unconditional text, sim_mode on; the video
        # must be decoded because its frames feed the next observation.
        with CudaTimer("synth_world_model", rec):
            pred_videos_1, _, pred_states = profiled_synthesis(
                model, "", observation, noise_shape,
                ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
                unconditional_guidance_scale=args.unconditional_guidance_scale,
                fs=model_input_fs, text_input=False,
                timestep_spacing=args.timestep_spacing,
                guidance_rescale=args.guidance_rescale,
                sim_mode=True, decode_video=True,
                records=rec, prefix="wm")

        # ⑥ update_obs_queue — push the first exe_steps predicted
        # frames/states back as the next observations.
        with WallTimer("update_obs_queue", rec):
            for idx in range(args.exe_steps):
                obs_u = {
                    'observation.images.top':
                    pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
                    'observation.state':
                    pred_states[0][idx:idx + 1],
                    'action':
                    torch.zeros_like(pred_actions[0][-1:]),
                }
                obs_u['observation.state'][:, ori_state_dim:] = 0.0
                cond_obs_queues = populate_queues(cond_obs_queues, obs_u)

        # ⑦ tensorboard_log (simulate — no actual writer, measure make_grid cost)
        with WallTimer("tensorboard_log", rec):
            for vid in [pred_videos_0, pred_videos_1]:
                if vid is not None and vid.dim() == 5:
                    v = vid.permute(2, 0, 1, 3, 4)
                    grids = [torchvision.utils.make_grid(f, nrow=1, padding=0) for f in v]
                    _ = torch.stack(grids, dim=0)

        # ⑧ save_results — policy video only when it was decoded; the
        # world-model video is always decoded so it is always saved.
        # NOTE(review): reconstructed from a mangled render — confirm the
        # second save_results call sits outside the `if` in the original.
        with WallTimer("save_results", rec):
            if pred_videos_0 is not None:
                save_results(pred_videos_0.cpu(),
                             os.path.join(tmp_dir, f"dm_{itr}.mp4"),
                             fps=args.save_fps)
            save_results(pred_videos_1.cpu(),
                         os.path.join(tmp_dir, f"wm_{itr}.mp4"),
                         fps=args.save_fps)

        # ⑨ cpu_transfer — cost of moving the executed frames to host.
        with CudaTimer("cpu_transfer", rec):
            _ = pred_videos_1[:, :, :args.exe_steps].cpu()

        # ── itr_total end ──
        itr_end.record()
        torch.cuda.synchronize()
        itr_total_ms = itr_start.elapsed_time(itr_end)
        rec["itr_total"].append(itr_total_ms)

        # Flatten: each stage has exactly one entry per itr
        itr_rec = {k: v[0] for k, v in rec.items()}
        all_records.append(itr_rec)

        # Print live progress
        print(f" itr {itr}: {itr_total_ms:.0f} ms total | "
              f"policy={itr_rec.get('synth_policy', 0):.0f} | "
              f"wm={itr_rec.get('synth_world_model', 0):.0f} | "
              f"save={itr_rec.get('save_results', 0):.0f} | "
              f"tb={itr_rec.get('tensorboard_log', 0):.0f}")

    return all_records
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
# Layer 1: Console report
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
def print_iteration_report(all_records, warmup=1):
    """Print per-stage timing tables (mean/std/min/max and % of itr_total).

    Args:
        all_records: list of {stage_name: ms} dicts, one per iteration.
        warmup: leading iterations to exclude; kept if excluding them
            would leave no data.
    """
    if len(all_records) <= warmup:
        records = all_records
    else:
        records = all_records[warmup:]
        print(f"\n(Skipping first {warmup} itr(s) as warmup)\n")

    # Stage keys in first-seen order across iterations (dict preserves it).
    ordered_keys = list(dict.fromkeys(k for rec in records for k in rec))

    # Keys containing '/' are synthesis sub-stages; the rest are top-level.
    top_keys = [k for k in ordered_keys if '/' not in k]
    sub_keys = [k for k in ordered_keys if '/' in k]

    def _print_table(keys, title):
        if not keys:
            return
        print("=" * 82)
        print(title)
        print("=" * 82)
        print(f"{'Stage':<35} {'Mean(ms)':>10} {'Std':>8} {'Min':>10} {'Max':>10} {'%':>7}")
        print("-" * 82)

        total_mean = np.mean([rec.get("itr_total", 0) for rec in records])
        for key in keys:
            samples = [rec.get(key, 0) for rec in records]
            mean = np.mean(samples)
            std = np.std(samples)
            mn = np.min(samples)
            mx = np.max(samples)
            pct = mean / total_mean * 100 if total_mean > 0 else 0
            print(f"{key:<35} {mean:>10.1f} {std:>8.1f} {mn:>10.1f} {mx:>10.1f} {pct:>6.1f}%")
        print("-" * 82)
        print()

    _print_table(top_keys, "TABLE 1: ITERATION-LEVEL BREAKDOWN")
    _print_table(sub_keys, "TABLE 2: SYNTHESIS SUB-STAGE BREAKDOWN")
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
# Layer 3: CSV output for A/B comparison
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
def write_csv(all_records, csv_path, warmup=1):
    """Write per-iteration timing to CSV for later comparison.

    Args:
        all_records: list of {stage_name: ms} dicts, one per iteration.
        csv_path: output path for the per-iteration table; a companion
            ``<stem>_summary.csv`` with mean/std/min/max rows is written
            alongside it.
        warmup: leading iterations to drop (kept if dropping would leave
            no data).
    """
    records = all_records[warmup:] if len(all_records) > warmup else all_records

    # Collect stage keys in first-seen order so columns are stable.
    all_keys = []
    seen = set()
    for rec in records:
        for k in rec:
            if k not in seen:
                all_keys.append(k)
                seen.add(k)

    with open(csv_path, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=['itr'] + all_keys)
        writer.writeheader()
        for i, rec in enumerate(records):
            row = {'itr': i}
            row.update({k: f"{rec.get(k, 0):.2f}" for k in all_keys})
            writer.writerow(row)

    # Also write a summary table.  Derive the path with splitext rather
    # than str.replace: replace() would corrupt a '.csv' occurring earlier
    # in the path and — worse — silently reuse csv_path (overwriting the
    # data file) when the extension differs.
    root, ext = os.path.splitext(csv_path)
    summary_path = f"{root}_summary{ext or '.csv'}"
    with open(summary_path, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=['stat'] + all_keys)
        writer.writeheader()
        for stat_name, stat_fn in [('mean', np.mean), ('std', np.std),
                                   ('min', np.min), ('max', np.max)]:
            row = {'stat': stat_name}
            row.update({k: f"{stat_fn([r.get(k, 0) for r in records]):.2f}"
                        for k in all_keys})
            writer.writerow(row)

    print(f">>> CSV written to: {csv_path}")
    print(f">>> Summary written to: {summary_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def compare_csvs(path_a, path_b):
    """Compare two summary CSVs (mean rows) and print a diff table.

    Stages whose mean differs by more than 50 ms are flagged with ``<<<``.
    """
    mean_a = pd.read_csv(path_a, index_col='stat').loc['mean'].astype(float)
    mean_b = pd.read_csv(path_b, index_col='stat').loc['mean'].astype(float)

    print("=" * 90)
    print(f"A/B COMPARISON: {os.path.basename(path_a)} vs {os.path.basename(path_b)}")
    print("=" * 90)
    print(f"{'Stage':<35} {'A(ms)':>10} {'B(ms)':>10} {'Diff':>10} {'Speedup':>10}")
    print("-" * 90)

    shared_cols = [c for c in mean_a.index if c in mean_b.index]
    for col in shared_cols:
        a_val = mean_a[col]
        b_val = mean_b[col]
        delta = b_val - a_val
        ratio = a_val / b_val if b_val > 0 else float('inf')
        flag = " <<<" if abs(delta) > 50 else ""
        print(f"{col:<35} {a_val:>10.1f} {b_val:>10.1f} {delta:>+10.1f} {ratio:>9.2f}x{flag}")

    print("-" * 90)
    total_a = mean_a.get('itr_total', 0)
    total_b = mean_b.get('itr_total', 0)
    print(f"{'itr_total':<35} {total_a:>10.1f} {total_b:>10.1f} "
          f"{total_b - total_a:>+10.1f} {total_a / total_b if total_b > 0 else 0:>9.2f}x")
    print()
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
# Layer 2: GPU timeline trace wrapper
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
def run_with_trace(model, args, config, noise_shape, device):
    """Run iterations under torch.profiler to generate Chrome/TensorBoard traces.

    Layer 2 of the profiler: replays the same policy-pass / world-model-pass
    iteration loop as the Layer-1 path, but wrapped in
    ``torch.profiler.profile`` with a TensorBoard trace handler so a GPU
    timeline trace is written to ``args.trace_dir``.

    Args:
        model: Loaded model; provides ``n_obs_steps_imagen`` and is passed to
            ``profiled_synthesis``.
        args: Parsed CLI namespace (see ``get_parser``); this function reads
            ``trace_dir``, ``prompt_dir``, ``dataset``, ``savedir``,
            ``frame_stride``, ``video_length``, ``n_iter``, ``exe_steps``,
            sampling hyper-parameters, and ``save_fps``.
        config: Config object whose ``data`` node is instantiated to build the
            test dataset used by ``prepare_init_input``.
        noise_shape: Latent noise shape list passed through to synthesis.
        device: Torch device observations are moved to.
    """
    trace_dir = args.trace_dir
    os.makedirs(trace_dir, exist_ok=True)

    # We need the same data setup as run_profiled_iterations.
    # Only the first prompt row of the dataset CSV is profiled.
    csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
    df = pd.read_csv(csv_path)
    sample = df.iloc[0]

    data_module = instantiate_from_config(config.data)
    data_module.setup()

    init_frame_path = get_init_frame_path(args.prompt_dir, sample)
    ori_fps = float(sample['fps'])
    fs = args.frame_stride
    # NOTE(review): float // int yields a float here; presumably
    # profiled_synthesis accepts a float fs — confirm.
    model_input_fs = ori_fps // fs

    # Load the transition HDF5: datasets become torch tensors,
    # file attributes are copied through unchanged.
    transition_path = get_transition_path(args.prompt_dir, sample)
    with h5py.File(transition_path, 'r') as h5f:
        transition_dict = {}
        for key in h5f.keys():
            transition_dict[key] = torch.tensor(h5f[key][()])
        for key in h5f.attrs.keys():
            transition_dict[key] = h5f.attrs[key]

    batch, ori_state_dim, ori_action_dim = prepare_init_input(
        0, init_frame_path, transition_dict, fs,
        data_module.test_datasets[args.dataset],
        n_obs_steps=model.n_obs_steps_imagen)

    # Seed observation: last image/state of the prepared batch, zero action.
    observation = {
        'observation.images.top':
        batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0),
        'observation.state':
        batch['observation.state'][-1].unsqueeze(0),
        'action':
        torch.zeros_like(batch['action'][-1]).unsqueeze(0),
    }
    observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}

    # Bounded deques carry the conditioning history between iterations.
    cond_obs_queues = {
        "observation.images.top": deque(maxlen=model.n_obs_steps_imagen),
        "observation.state": deque(maxlen=model.n_obs_steps_imagen),
        "action": deque(maxlen=args.video_length),
    }
    cond_obs_queues = populate_queues(cond_obs_queues, observation)

    tmp_dir = os.path.join(args.savedir, "profile_tmp")
    os.makedirs(tmp_dir, exist_ok=True)
    prompt_text = sample['instruction']

    # Total iterations: warmup + active
    n_warmup = 1
    n_active = min(args.n_iter, 2)  # trace 2 active iterations max
    n_total = n_warmup + n_active

    print(f">>> GPU trace: {n_warmup} warmup + {n_active} active iterations")
    print(f">>> Trace output: {trace_dir}")

    with torch.no_grad(), torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            schedule=torch.profiler.schedule(
                wait=0, warmup=n_warmup, active=n_active, repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
            record_shapes=True,
            with_stack=True,
    ) as prof:
        for itr_idx in range(n_total):
            phase = "warmup" if itr_idx < n_warmup else "active"
            print(f" trace itr {itr_idx} ({phase})...")

            # ── One full iteration (same logic as run_inference) ──
            obs_loc = {
                'observation.images.top':
                torch.stack(list(cond_obs_queues['observation.images.top']),
                            dim=1).permute(0, 2, 1, 3, 4),
                'observation.state':
                torch.stack(list(cond_obs_queues['observation.state']), dim=1),
                'action':
                torch.stack(list(cond_obs_queues['action']), dim=1),
            }
            obs_loc = {k: v.to(device) for k, v in obs_loc.items()}

            # Policy pass.  Timing records are discarded here — the GPU
            # trace itself is the output of this function.
            dummy_rec = defaultdict(list)
            pv0, pa, _ = profiled_synthesis(
                model, prompt_text, obs_loc, noise_shape,
                ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
                unconditional_guidance_scale=args.unconditional_guidance_scale,
                fs=model_input_fs, text_input=True,
                timestep_spacing=args.timestep_spacing,
                guidance_rescale=args.guidance_rescale,
                sim_mode=False,
                decode_video=not args.fast_policy_no_decode,
                records=dummy_rec, prefix="policy")

            # Feed predicted actions into the queue one step at a time,
            # zeroing the padded action dimensions beyond ori_action_dim.
            for idx in range(len(pa[0])):
                oa = {'action': pa[0][idx:idx + 1]}
                oa['action'][:, ori_action_dim:] = 0.0
                populate_queues(cond_obs_queues, oa)

            # Re-stack for world model
            obs_loc2 = {
                'observation.images.top':
                torch.stack(list(cond_obs_queues['observation.images.top']),
                            dim=1).permute(0, 2, 1, 3, 4),
                'observation.state':
                torch.stack(list(cond_obs_queues['observation.state']), dim=1),
                'action':
                torch.stack(list(cond_obs_queues['action']), dim=1),
            }
            obs_loc2 = {k: v.to(device) for k, v in obs_loc2.items()}

            # World model pass (no text conditioning, sim_mode on).
            pv1, _, ps = profiled_synthesis(
                model, "", obs_loc2, noise_shape,
                ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
                unconditional_guidance_scale=args.unconditional_guidance_scale,
                fs=model_input_fs, text_input=False,
                timestep_spacing=args.timestep_spacing,
                guidance_rescale=args.guidance_rescale,
                sim_mode=True, decode_video=True,
                records=dummy_rec, prefix="wm")

            # Update obs queue with the first exe_steps predicted frames and
            # states; padded state dims beyond ori_state_dim are zeroed.
            for idx in range(args.exe_steps):
                ou = {
                    'observation.images.top':
                    pv1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
                    'observation.state': ps[0][idx:idx + 1],
                    'action': torch.zeros_like(pa[0][-1:]),
                }
                ou['observation.state'][:, ori_state_dim:] = 0.0
                populate_queues(cond_obs_queues, ou)

            # Save results (captures CPU stall in trace).
            # pv0 is None when --fast_policy_no_decode skipped the decode.
            if pv0 is not None:
                save_results(pv0.cpu(),
                             os.path.join(tmp_dir, f"trace_dm_{itr_idx}.mp4"),
                             fps=args.save_fps)
            save_results(pv1.cpu(),
                         os.path.join(tmp_dir, f"trace_wm_{itr_idx}.mp4"),
                         fps=args.save_fps)

            # Advance the profiler schedule (warmup → active → trace flush).
            prof.step()

    print(f">>> Trace saved to {trace_dir}")
    print(" View with: tensorboard --logdir", trace_dir)
    print(" Or open the .json file in chrome://tracing")
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
# Argument parser
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
def get_parser():
    """Build the CLI parser for the iteration profiler.

    Returns:
        argparse.ArgumentParser: parser covering compare mode, model/data
        paths, inference hyper-parameters, and profiling controls.
    """
    parser = argparse.ArgumentParser(description="Profile full iteration loop")
    add = parser.add_argument

    # Compare mode (no model needed)
    add("--compare", nargs=2, metavar=("A_SUMMARY", "B_SUMMARY"),
        help="Compare two summary CSVs and exit")

    # Model / data
    add("--ckpt_path", type=str, default=None)
    add("--config", type=str, default=None)
    add("--prompt_dir", type=str, default=None)
    add("--dataset", type=str, default=None)
    add("--savedir", type=str, default="profile_output")

    # Inference params (match world_model_interaction.py)
    add("--ddim_steps", type=int, default=50)
    add("--ddim_eta", type=float, default=1.0)
    add("--bs", type=int, default=1)
    add("--height", type=int, default=320)
    add("--width", type=int, default=512)
    add("--frame_stride", type=int, default=4)
    add("--unconditional_guidance_scale", type=float, default=1.0)
    add("--video_length", type=int, default=16)
    add("--timestep_spacing", type=str, default="uniform_trailing")
    add("--guidance_rescale", type=float, default=0.7)
    add("--exe_steps", type=int, default=16)
    add("--n_iter", type=int, default=5)
    add("--save_fps", type=int, default=8)
    add("--seed", type=int, default=123)
    add("--perframe_ae", action='store_true', default=False)
    add("--vae_dtype", type=str, choices=["fp32", "bf16"], default="bf16")
    add("--fast_policy_no_decode", action='store_true', default=False)

    # Profiling control
    add("--warmup", type=int, default=1,
        help="Number of warmup iterations to skip in statistics")
    add("--csv", type=str, default=None,
        help="Write per-iteration timing to this CSV file")
    add("--trace", action='store_true', default=False,
        help="Enable Layer 2: GPU timeline trace")
    add("--trace_dir", type=str, default="./profile_traces",
        help="Directory for trace output")

    return parser
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
# Main
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
def main():
    """Entry point: compare mode, or load the model and run the profilers."""
    patch_norm_bypass_autocast()
    parser = get_parser()
    args = parser.parse_args()

    # ── Compare mode: no model needed ──
    if args.compare:
        path_a, path_b = args.compare
        compare_csvs(path_a, path_b)
        return

    # ── Validate required args ──
    missing = [name for name in ('ckpt_path', 'config', 'prompt_dir', 'dataset')
               if getattr(args, name) is None]
    if missing:
        parser.error(f"--{missing[0]} is required for profiling mode")

    seed_everything(args.seed)
    os.makedirs(args.savedir, exist_ok=True)

    banner = "=" * 60

    # ── Load model ──
    print(banner)
    print("PROFILE ITERATION — Loading model...")
    print(banner)
    model, config = load_model(args)
    device = next(model.parameters()).device

    # Latents are 8x spatially downsampled relative to pixel space.
    latent_h, latent_w = args.height // 8, args.width // 8
    channels = model.model.diffusion_model.out_channels
    noise_shape = [args.bs, channels, args.video_length, latent_h, latent_w]
    print(f">>> Noise shape: {noise_shape}")
    print(f">>> DDIM steps: {args.ddim_steps}")
    print(f">>> fast_policy_no_decode: {args.fast_policy_no_decode}")

    # ── Layer 2: GPU trace (optional) ──
    if args.trace:
        with torch.no_grad():
            run_with_trace(model, args, config, noise_shape, device)
        print()

    # ── Layer 1: Iteration-level breakdown ──
    print(banner)
    print("LAYER 1: ITERATION-LEVEL PROFILING")
    print(banner)
    with torch.no_grad():
        all_records = run_profiled_iterations(
            model, args, config, noise_shape, device)

    # Print report
    print_iteration_report(all_records, warmup=args.warmup)

    # ── Layer 3: CSV output for A/B comparison ──
    if args.csv:
        write_csv(all_records, args.csv, warmup=args.warmup)

    print("Done.")


if __name__ == '__main__':
    main()
|
||||||
5
unitree_z1_dual_arm_cleanup_pencils/case1/run_profile.sh
Normal file
5
unitree_z1_dual_arm_cleanup_pencils/case1/run_profile.sh
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
#!/bin/bash
# Run Layer-1 iteration profiling for the dual-arm pencil-cleanup case and
# save both the per-iteration CSV and the console log.
#
# Fixes vs. original: the shebang was garbled as `#\!/bin/bash` (the escaped
# `!` prevents the kernel from recognizing it), and `tee` opened its log file
# before Python could create the output directory — create it up front.

res_dir="unitree_z1_dual_arm_cleanup_pencils/case1"
dataset="unitree_z1_dual_arm_cleanup_pencils"
out_dir="${res_dir}/profile_output"

# `tee` opens the log file immediately; make sure its directory exists.
mkdir -p "${out_dir}"

TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 \
python3 scripts/evaluation/profile_iteration.py \
    --seed 123 \
    --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
    --config configs/inference/world_model_interaction.yaml \
    --savedir "${out_dir}" \
    --prompt_dir "${res_dir}/world_model_interaction_prompts" \
    --dataset ${dataset} \
    --bs 1 --height 320 --width 512 \
    --unconditional_guidance_scale 1.0 \
    --ddim_steps 50 --ddim_eta 1.0 \
    --video_length 16 --frame_stride 4 --exe_steps 16 \
    --n_iter 5 --warmup 1 \
    --timestep_spacing uniform_trailing --guidance_rescale 0.7 \
    --perframe_ae --vae_dtype bf16 --fast_policy_no_decode \
    --csv "${out_dir}/baseline.csv" \
    2>&1 | tee "${out_dir}/profile.log"
|
||||||
Reference in New Issue
Block a user