remove profile

This commit is contained in:
qhy
2026-02-10 11:28:26 +08:00
parent ff920b85a2
commit f1f92072e6

View File

@@ -9,12 +9,9 @@ import logging
import einops
import warnings
import imageio
import time
import json
import atexit
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass, field, asdict
from contextlib import nullcontext
from typing import Optional, Dict, List, Any, Mapping
from pytorch_lightning import seed_everything
@@ -26,375 +23,12 @@ from torch import nn
from eval_utils import populate_queues
from collections import deque
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.utils.utils import instantiate_from_config
# ========== Profiling Infrastructure ==========
@dataclass
class TimingRecord:
    """A single named timing measurement, optionally nested via ``children``."""
    name: str
    start_time: float = 0.0
    end_time: float = 0.0
    cuda_time_ms: float = 0.0
    count: int = 0
    children: List['TimingRecord'] = field(default_factory=list)

    @property
    def cpu_time_ms(self) -> float:
        """Elapsed CPU wall time in milliseconds (end - start, seconds -> ms)."""
        return (self.end_time - self.start_time) * 1000

    def to_dict(self) -> dict:
        """Serialize this record and all nested children to plain dicts."""
        payload = {
            'name': self.name,
            'cpu_time_ms': self.cpu_time_ms,
            'cuda_time_ms': self.cuda_time_ms,
            'count': self.count,
        }
        payload['children'] = [child.to_dict() for child in self.children]
        return payload
class ProfilerManager:
    """Manages macro and micro-level profiling.

    Collects three kinds of data:
      * per-section CPU wall time and CUDA event time via ``profile_section``,
      * GPU memory snapshots via ``record_memory``,
      * operator-level statistics aggregated from ``torch.profiler`` traces.

    When ``enabled`` is False, every public method is a cheap no-op, so the
    manager can stay wired into the code path unconditionally.
    """

    def __init__(
        self,
        enabled: bool = False,
        output_dir: str = "./profile_output",
        profile_detail: str = "light",
    ):
        self.enabled = enabled
        self.output_dir = output_dir
        self.profile_detail = profile_detail
        # Section name -> list of CPU wall times in ms (one entry per call).
        self.macro_timings: Dict[str, List[float]] = {}
        # Section name -> list of (cpu_ms, cuda_ms) tuples, parallel to above.
        self.cuda_events: Dict[str, List[tuple]] = {}
        self.memory_snapshots: List[Dict] = []
        self.pytorch_profiler = None
        # Updated externally by the inference loop; used to tag traces/snapshots.
        self.current_iteration = 0
        # Per-operator stats accumulated across all profiler traces.
        self.operator_stats: Dict[str, Dict] = {}
        self.profiler_config = self._build_profiler_config(profile_detail)
        if enabled:
            os.makedirs(output_dir, exist_ok=True)

    def _build_profiler_config(self, profile_detail: str) -> Dict[str, Any]:
        """Return profiler settings based on the requested detail level.

        "full" turns on shapes/memory/stack/FLOPs recording (expensive);
        "light" disables all of it. Raises ValueError on any other value.
        """
        if profile_detail not in ("light", "full"):
            raise ValueError(f"Unsupported profile_detail: {profile_detail}")
        if profile_detail == "full":
            return {
                "record_shapes": True,
                "profile_memory": True,
                "with_stack": True,
                "with_flops": True,
                "with_modules": True,
                "group_by_input_shape": True,
            }
        return {
            "record_shapes": False,
            "profile_memory": False,
            "with_stack": False,
            "with_flops": False,
            "with_modules": False,
            "group_by_input_shape": False,
        }

    @contextmanager
    def profile_section(self, name: str, sync_cuda: bool = True):
        """Context manager for profiling a code section.

        Records CPU wall time always, and CUDA event time when a GPU is
        available. ``sync_cuda`` forces a device synchronize before and after
        the section so the CPU timing brackets all queued GPU work.
        """
        if not self.enabled:
            # Disabled: still yield so callers can use this unconditionally.
            yield
            return
        if sync_cuda and torch.cuda.is_available():
            torch.cuda.synchronize()
        start_event = None
        end_event = None
        if torch.cuda.is_available():
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
        start_time = time.perf_counter()
        try:
            yield
        finally:
            if sync_cuda and torch.cuda.is_available():
                torch.cuda.synchronize()
            end_time = time.perf_counter()
            cpu_time_ms = (end_time - start_time) * 1000
            cuda_time_ms = 0.0
            if start_event is not None and end_event is not None:
                # NOTE(review): the end event is recorded after the optional
                # synchronize above, so cuda_time_ms may include that sync.
                end_event.record()
                torch.cuda.synchronize()
                cuda_time_ms = start_event.elapsed_time(end_event)
            if name not in self.macro_timings:
                self.macro_timings[name] = []
            self.macro_timings[name].append(cpu_time_ms)
            if name not in self.cuda_events:
                self.cuda_events[name] = []
            self.cuda_events[name].append((cpu_time_ms, cuda_time_ms))

    def record_memory(self, tag: str = ""):
        """Record current GPU memory state (no-op if disabled or CPU-only)."""
        if not self.enabled or not torch.cuda.is_available():
            return
        snapshot = {
            'tag': tag,
            'iteration': self.current_iteration,
            'allocated_mb': torch.cuda.memory_allocated() / 1024**2,
            'reserved_mb': torch.cuda.memory_reserved() / 1024**2,
            'max_allocated_mb': torch.cuda.max_memory_allocated() / 1024**2,
        }
        self.memory_snapshots.append(snapshot)

    def start_pytorch_profiler(self, wait: int = 1, warmup: int = 1, active: int = 3):
        """Start PyTorch profiler for operator-level analysis.

        Returns a context manager: ``nullcontext`` when profiling is disabled,
        otherwise a ``torch.profiler.profile`` configured with the
        wait/warmup/active schedule and this manager's detail settings.
        Traces land in ``_trace_handler``.
        """
        if not self.enabled:
            return nullcontext()
        self.pytorch_profiler = torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            schedule=torch.profiler.schedule(
                wait=wait, warmup=warmup, active=active, repeat=1
            ),
            on_trace_ready=self._trace_handler,
            record_shapes=self.profiler_config["record_shapes"],
            profile_memory=self.profiler_config["profile_memory"],
            with_stack=self.profiler_config["with_stack"],
            with_flops=self.profiler_config["with_flops"],
            with_modules=self.profiler_config["with_modules"],
        )
        return self.pytorch_profiler

    def _trace_handler(self, prof):
        """Handle profiler trace output.

        Exports a Chrome trace for the current iteration and folds the
        per-operator averages into ``operator_stats``.
        """
        trace_path = os.path.join(
            self.output_dir,
            f"trace_iter_{self.current_iteration}.json"
        )
        prof.export_chrome_trace(trace_path)
        # Extract operator statistics
        key_averages = prof.key_averages(
            group_by_input_shape=self.profiler_config["group_by_input_shape"]
        )
        for evt in key_averages:
            op_name = evt.key
            if op_name not in self.operator_stats:
                self.operator_stats[op_name] = {
                    'count': 0,
                    'cpu_time_total_us': 0,
                    'cuda_time_total_us': 0,
                    'self_cpu_time_total_us': 0,
                    'self_cuda_time_total_us': 0,
                    'cpu_memory_usage': 0,
                    'cuda_memory_usage': 0,
                    'flops': 0,
                }
            stats = self.operator_stats[op_name]
            stats['count'] += evt.count
            stats['cpu_time_total_us'] += evt.cpu_time_total
            stats['cuda_time_total_us'] += evt.cuda_time_total
            stats['self_cpu_time_total_us'] += evt.self_cpu_time_total
            stats['self_cuda_time_total_us'] += evt.self_cuda_time_total
            # Memory/FLOPs fields only exist when the corresponding profiler
            # options are on, hence the hasattr guards.
            if hasattr(evt, 'cpu_memory_usage'):
                stats['cpu_memory_usage'] += evt.cpu_memory_usage
            if hasattr(evt, 'cuda_memory_usage'):
                stats['cuda_memory_usage'] += evt.cuda_memory_usage
            if hasattr(evt, 'flops') and evt.flops:
                stats['flops'] += evt.flops

    def step_profiler(self):
        """Step the PyTorch profiler (advances its wait/warmup/active schedule)."""
        if self.pytorch_profiler is not None:
            self.pytorch_profiler.step()

    def generate_report(self) -> str:
        """Generate comprehensive profiling report as a single string."""
        if not self.enabled:
            return "Profiling disabled."
        report_lines = []
        report_lines.append("=" * 80)
        report_lines.append("PERFORMANCE PROFILING REPORT")
        report_lines.append("=" * 80)
        report_lines.append("")
        # Macro-level timing summary
        report_lines.append("-" * 40)
        report_lines.append("MACRO-LEVEL TIMING SUMMARY")
        report_lines.append("-" * 40)
        report_lines.append(f"{'Section':<40} {'Count':>8} {'Total(ms)':>12} {'Avg(ms)':>12} {'CUDA Avg(ms)':>14}")
        report_lines.append("-" * 86)
        total_time = 0
        # timing_data is built alongside the text report; kept for callers /
        # future use even though this method only returns the text.
        timing_data = []
        for name, times in sorted(self.macro_timings.items()):
            cuda_times = [ct for _, ct in self.cuda_events.get(name, [])]
            avg_time = np.mean(times)
            avg_cuda = np.mean(cuda_times) if cuda_times else 0
            total = sum(times)
            total_time += total
            timing_data.append({
                'name': name,
                'count': len(times),
                'total_ms': total,
                'avg_ms': avg_time,
                'cuda_avg_ms': avg_cuda,
                'times': times,
                'cuda_times': cuda_times,
            })
            report_lines.append(f"{name:<40} {len(times):>8} {total:>12.2f} {avg_time:>12.2f} {avg_cuda:>14.2f}")
        report_lines.append("-" * 86)
        report_lines.append(f"{'TOTAL':<40} {'':<8} {total_time:>12.2f}")
        report_lines.append("")
        # Memory summary
        if self.memory_snapshots:
            report_lines.append("-" * 40)
            report_lines.append("GPU MEMORY SUMMARY")
            report_lines.append("-" * 40)
            max_alloc = max(s['max_allocated_mb'] for s in self.memory_snapshots)
            avg_alloc = np.mean([s['allocated_mb'] for s in self.memory_snapshots])
            report_lines.append(f"Peak allocated: {max_alloc:>10.2f} MB")
            report_lines.append(f"Average allocated: {avg_alloc:>10.2f} MB")
            report_lines.append("")
        # Top operators by CUDA time
        if self.operator_stats:
            report_lines.append("-" * 40)
            report_lines.append("TOP 30 OPERATORS BY CUDA TIME")
            report_lines.append("-" * 40)
            sorted_ops = sorted(
                self.operator_stats.items(),
                key=lambda x: x[1]['cuda_time_total_us'],
                reverse=True
            )[:30]
            report_lines.append(f"{'Operator':<50} {'Count':>8} {'CUDA(ms)':>12} {'CPU(ms)':>12} {'Self CUDA(ms)':>14}")
            report_lines.append("-" * 96)
            for op_name, stats in sorted_ops:
                # Truncate long operator names
                display_name = op_name[:47] + "..." if len(op_name) > 50 else op_name
                report_lines.append(
                    f"{display_name:<50} {stats['count']:>8} "
                    f"{stats['cuda_time_total_us']/1000:>12.2f} "
                    f"{stats['cpu_time_total_us']/1000:>12.2f} "
                    f"{stats['self_cuda_time_total_us']/1000:>14.2f}"
                )
            report_lines.append("")
            # Compute category breakdown
            report_lines.append("-" * 40)
            report_lines.append("OPERATOR CATEGORY BREAKDOWN")
            report_lines.append("-" * 40)
            # Keyword buckets matched by substring against lowercased op names;
            # first matching category wins (order of dict insertion matters).
            categories = {
                'Attention': ['attention', 'softmax', 'bmm', 'baddbmm'],
                'Convolution': ['conv', 'cudnn'],
                'Normalization': ['norm', 'layer_norm', 'batch_norm', 'group_norm'],
                'Activation': ['relu', 'gelu', 'silu', 'sigmoid', 'tanh'],
                'Linear/GEMM': ['linear', 'addmm', 'mm', 'gemm'],
                'Memory': ['copy', 'contiguous', 'view', 'reshape', 'permute', 'transpose'],
                'Elementwise': ['add', 'mul', 'div', 'sub', 'pow', 'exp', 'sqrt'],
            }
            category_times = {cat: 0.0 for cat in categories}
            category_times['Other'] = 0.0
            for op_name, stats in self.operator_stats.items():
                op_lower = op_name.lower()
                categorized = False
                for cat, keywords in categories.items():
                    if any(kw in op_lower for kw in keywords):
                        category_times[cat] += stats['cuda_time_total_us']
                        categorized = True
                        break
                if not categorized:
                    category_times['Other'] += stats['cuda_time_total_us']
            total_op_time = sum(category_times.values())
            report_lines.append(f"{'Category':<30} {'CUDA Time(ms)':>15} {'Percentage':>12}")
            report_lines.append("-" * 57)
            for cat, time_us in sorted(category_times.items(), key=lambda x: -x[1]):
                pct = (time_us / total_op_time * 100) if total_op_time > 0 else 0
                report_lines.append(f"{cat:<30} {time_us/1000:>15.2f} {pct:>11.1f}%")
            report_lines.append("")
        report = "\n".join(report_lines)
        return report

    def save_results(self):
        """Save all profiling results to files.

        Writes the text report and a JSON dump of raw timings / memory
        snapshots / operator stats into ``output_dir``, then echoes the
        report to stdout. No-op when disabled.
        """
        if not self.enabled:
            return
        # Save report
        report = self.generate_report()
        report_path = os.path.join(self.output_dir, "profiling_report.txt")
        with open(report_path, 'w') as f:
            f.write(report)
        print(f">>> Profiling report saved to: {report_path}")
        # Save detailed JSON data
        data = {
            'macro_timings': {
                name: {
                    'times': times,
                    'cuda_times': [ct for _, ct in self.cuda_events.get(name, [])]
                }
                for name, times in self.macro_timings.items()
            },
            'memory_snapshots': self.memory_snapshots,
            'operator_stats': self.operator_stats,
        }
        json_path = os.path.join(self.output_dir, "profiling_data.json")
        with open(json_path, 'w') as f:
            json.dump(data, f, indent=2)
        print(f">>> Detailed profiling data saved to: {json_path}")
        # Print summary to console
        print("\n" + report)
# Global profiler instance
_profiler: Optional[ProfilerManager] = None
def get_profiler() -> ProfilerManager:
    """Return the process-wide profiler, lazily creating a disabled one."""
    global _profiler
    if _profiler is None:
        # First access without init_profiler(): fall back to a no-op profiler.
        _profiler = ProfilerManager(enabled=False)
    return _profiler
def init_profiler(enabled: bool, output_dir: str, profile_detail: str) -> ProfilerManager:
    """Initialize and return the global profiler, replacing any existing one."""
    global _profiler
    manager = ProfilerManager(
        enabled=enabled,
        output_dir=output_dir,
        profile_detail=profile_detail,
    )
    _profiler = manager
    return manager
# ========== Async I/O ==========
_io_executor: Optional[ThreadPoolExecutor] = None
_io_futures: List[Any] = []
@@ -447,28 +81,6 @@ def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
_io_futures.append(fut)
def _log_to_tb_sync(writer, video_cpu: Tensor, tag: str, fps: int) -> None:
    """Blocking TensorBoard video write on a CPU tensor (runs in a worker thread).

    Silently skips anything that is not a 5-D tensor; values are assumed to be
    in [-1, 1] and are rescaled to [0, 1] before logging.
    """
    if video_cpu.dim() != 5:
        return
    batch = video_cpu.shape[0]
    # Move the time-like axis (dim 2) to the front so we iterate per frame.
    # assumes layout is (B, C, T, H, W) -> (T, B, C, H, W) — TODO confirm
    per_frame = video_cpu.permute(2, 0, 1, 3, 4)
    frame_grids = []
    for framesheet in per_frame:
        frame_grids.append(
            torchvision.utils.make_grid(framesheet, nrow=int(batch), padding=0))
    stacked = torch.stack(frame_grids, dim=0)
    stacked = ((stacked + 1.0) / 2.0).unsqueeze(dim=0)
    writer.add_video(tag, stacked, fps=fps)
def log_to_tensorboard_async(writer, data: Tensor, tag: str, fps: int = 10) -> None:
    """Queue a TensorBoard video write on the background I/O thread pool.

    Non-tensors and tensors that are not 5-D are ignored. The tensor is
    detached and copied to CPU here so the worker thread never touches
    autograd state or GPU memory.
    """
    if not (isinstance(data, torch.Tensor) and data.dim() == 5):
        return
    cpu_copy = data.detach().cpu()
    future = _get_io_executor().submit(_log_to_tb_sync, writer, cpu_copy, tag, fps)
    _io_futures.append(future)
# ========== Original Functions ==========
def get_device_from_parameters(module: nn.Module) -> torch.device:
@@ -590,9 +202,7 @@ def load_model_checkpoint(model: nn.Module,
def maybe_cast_module(module: nn.Module | None,
dtype: torch.dtype,
label: str,
profiler: Optional[ProfilerManager] = None,
profile_name: Optional[str] = None) -> None:
label: str) -> None:
if module is None:
return
try:
@@ -603,10 +213,6 @@ def maybe_cast_module(module: nn.Module | None,
if param.dtype == dtype:
print(f">>> {label} already {dtype}; skip cast")
return
ctx = nullcontext()
if profiler is not None and profile_name:
ctx = profiler.profile_section(profile_name)
with ctx:
module.to(dtype=dtype)
print(f">>> {label} cast to {dtype}")
@@ -825,8 +431,6 @@ def get_latent_z(model, videos: Tensor) -> Tensor:
Returns:
Tensor: Latent video tensor of shape [B, C, T, H, W].
"""
profiler = get_profiler()
with profiler.profile_section("get_latent_z/encode"):
b, c, t, h, w = videos.shape
x = rearrange(videos, 'b c t h w -> (b t) c h w')
vae_ctx = nullcontext()
@@ -941,8 +545,6 @@ def image_guided_synthesis_sim_mode(
actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
"""
profiler = get_profiler()
b, _, t, _, _ = noise_shape
ddim_sampler = getattr(model, "_ddim_sampler", None)
if ddim_sampler is None:
@@ -952,7 +554,6 @@ def image_guided_synthesis_sim_mode(
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
with profiler.profile_section("synthesis/conditioning_prep"):
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
if getattr(model, "encoder_bf16", False) and model.device.type == "cuda":
@@ -1038,7 +639,6 @@ def image_guided_synthesis_sim_mode(
cond_z0 = None
if ddim_sampler is not None:
with profiler.profile_section("synthesis/ddim_sampling"):
autocast_ctx = nullcontext()
if diffusion_autocast_dtype is not None and model.device.type == "cuda":
autocast_ctx = torch.autocast("cuda", dtype=diffusion_autocast_dtype)
@@ -1061,7 +661,6 @@ def image_guided_synthesis_sim_mode(
**kwargs)
# Reconstruct from latent to pixel space
with profiler.profile_section("synthesis/decode_first_stage"):
if getattr(model, "vae_bf16", False):
if samples.dtype != torch.bfloat16:
samples = samples.to(dtype=torch.bfloat16)
@@ -1091,13 +690,8 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
Returns:
None
"""
profiler = get_profiler()
# Create inference and tensorboard dirs
# Create inference dir
os.makedirs(args.savedir + '/inference', exist_ok=True)
log_dir = args.savedir + f"/tensorboard"
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir=log_dir)
# Load prompt
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
@@ -1110,7 +704,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
if os.path.exists(prepared_path):
# ---- Fast path: load the fully-prepared model ----
print(f">>> Loading prepared model from {prepared_path} ...")
with profiler.profile_section("model_loading/prepared"):
model = torch.load(prepared_path,
map_location=f"cuda:{gpu_no}",
weights_only=False,
@@ -1122,7 +715,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
print(f">>> Prepared model loaded.")
else:
# ---- Normal path: construct + checkpoint + casting ----
with profiler.profile_section("model_loading/config"):
config['model']['params']['wma_config']['params'][
'use_checkpoint'] = False
model = instantiate_from_config(config.model)
@@ -1130,7 +722,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
with profiler.profile_section("model_loading/checkpoint"):
model = load_model_checkpoint(model, args.ckpt_path,
device=f"cuda:{gpu_no}")
model.eval()
@@ -1143,8 +734,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
model.model,
torch.bfloat16,
"diffusion backbone",
profiler=profiler,
profile_name="model_loading/diffusion_bf16",
)
diffusion_autocast_dtype = torch.bfloat16
print(">>> diffusion backbone set to bfloat16")
@@ -1155,8 +744,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
model.first_stage_model,
vae_weight_dtype,
"VAE",
profiler=profiler,
profile_name="model_loading/vae_cast",
)
model.vae_bf16 = args.vae_dtype == "bf16"
print(f">>> VAE dtype set to {args.vae_dtype}")
@@ -1195,16 +782,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
model.cond_stage_model,
encoder_weight_dtype,
"cond_stage_model",
profiler=profiler,
profile_name="model_loading/encoder_cond_cast",
)
if hasattr(model, "embedder") and model.embedder is not None:
maybe_cast_module(
model.embedder,
encoder_weight_dtype,
"embedder",
profiler=profiler,
profile_name="model_loading/encoder_embedder_cast",
)
model.encoder_bf16 = encoder_bf16
model.encoder_mode = encoder_mode
@@ -1220,24 +803,18 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
model.image_proj_model,
projector_weight_dtype,
"image_proj_model",
profiler=profiler,
profile_name="model_loading/projector_image_cast",
)
if hasattr(model, "state_projector") and model.state_projector is not None:
maybe_cast_module(
model.state_projector,
projector_weight_dtype,
"state_projector",
profiler=profiler,
profile_name="model_loading/projector_state_cast",
)
if hasattr(model, "action_projector") and model.action_projector is not None:
maybe_cast_module(
model.action_projector,
projector_weight_dtype,
"action_projector",
profiler=profiler,
profile_name="model_loading/projector_action_cast",
)
if hasattr(model, "projector_bf16"):
model.projector_bf16 = projector_bf16
@@ -1269,14 +846,11 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Build normalizer (always needed, independent of model loading path)
logging.info("***** Configing Data *****")
with profiler.profile_section("data_loading"):
data = instantiate_from_config(config.data)
data.setup()
print(">>> Dataset is successfully loaded ...")
device = get_device_from_parameters(model)
profiler.record_memory("after_model_load")
# Run over data
assert (args.height % 16 == 0) and (
args.width % 16
@@ -1290,10 +864,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
print(f'>>> Generate {n_frames} frames under each generation ...')
noise_shape = [args.bs, channels, n_frames, h, w]
# Determine profiler iterations
profile_active_iters = getattr(args, 'profile_iterations', 3)
use_pytorch_profiler = profiler.enabled and profile_active_iters > 0
# Start inference
for idx in range(0, len(df)):
sample = df.iloc[idx]
@@ -1309,7 +879,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Load transitions to get the initial state later
transition_path = get_transition_path(args.prompt_dir, sample)
with profiler.profile_section("load_transitions"):
with h5py.File(transition_path, 'r') as h5f:
transition_dict = {}
for key in h5f.keys():
@@ -1337,7 +906,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
}
# Obtain initial frame and state
with profiler.profile_section("prepare_init_input"):
start_idx = 0
model_input_fs = ori_fps // fs
batch, ori_state_dim, ori_action_dim = prepare_init_input(
@@ -1360,24 +928,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Update observation queues
cond_obs_queues = populate_queues(cond_obs_queues, observation)
# Setup PyTorch profiler context if enabled
pytorch_prof_ctx = nullcontext()
if use_pytorch_profiler:
pytorch_prof_ctx = profiler.start_pytorch_profiler(
wait=1, warmup=1, active=profile_active_iters
)
# Multi-round interaction with the world-model
with pytorch_prof_ctx:
for itr in tqdm(range(args.n_iter)):
log_every = max(1, args.step_log_every)
log_step = (itr % log_every == 0)
profiler.current_iteration = itr
profiler.record_memory(f"iter_{itr}_start")
with profiler.profile_section("iteration_total"):
# Get observation
with profiler.profile_section("prepare_observation"):
observation = {
'observation.images.top':
torch.stack(list(
@@ -1394,7 +950,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Use world-model in policy to generate action
if log_step:
print(f'>>> Step {itr}: generating actions ...')
with profiler.profile_section("action_generation"):
pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
model,
sample['instruction'],
@@ -1412,7 +967,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
diffusion_autocast_dtype=diffusion_autocast_dtype)
# Update future actions in the observation queues
with profiler.profile_section("update_action_queues"):
for act_idx in range(len(pred_actions[0])):
obs_update = {'action': pred_actions[0][act_idx:act_idx + 1]}
obs_update['action'][:, ori_action_dim:] = 0.0
@@ -1420,7 +974,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
obs_update)
# Collect data for interacting the world-model using the predicted actions
with profiler.profile_section("prepare_wm_observation"):
observation = {
'observation.images.top':
torch.stack(list(
@@ -1437,7 +990,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Interaction with the world-model
if log_step:
print(f'>>> Step {itr}: interacting with world model ...')
with profiler.profile_section("world_model_interaction"):
pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
model,
"",
@@ -1454,7 +1006,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
guidance_rescale=args.guidance_rescale,
diffusion_autocast_dtype=diffusion_autocast_dtype)
with profiler.profile_section("update_state_queues"):
for step_idx in range(args.exe_steps):
obs_update = {
'observation.images.top':
@@ -1470,20 +1021,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
obs_update)
# Save the imagen videos for decision-making (async)
with profiler.profile_section("save_results"):
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
log_to_tensorboard_async(writer,
pred_videos_0,
sample_tag,
fps=args.save_fps)
# Save videos environment changes via world-model interaction
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
log_to_tensorboard_async(writer,
pred_videos_1,
sample_tag,
fps=args.save_fps)
# Save the imagen videos for decision-making
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
save_results_async(pred_videos_0,
sample_video_file,
@@ -1498,24 +1035,13 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Collect the result of world-model interactions
wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu())
profiler.record_memory(f"iter_{itr}_end")
profiler.step_profiler()
full_video = torch.cat(wm_video, dim=2)
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
log_to_tensorboard_async(writer,
full_video,
sample_tag,
fps=args.save_fps)
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
save_results_async(full_video, sample_full_video_file, fps=args.save_fps)
# Wait for all async I/O to complete before profiling report
# Wait for all async I/O to complete
_flush_io()
# Save profiling results
profiler.save_results()
def get_parser():
parser = argparse.ArgumentParser()
@@ -1704,32 +1230,6 @@ def get_parser():
type=int,
default=8,
help="fps for the saving video")
# Profiling arguments
parser.add_argument(
"--profile",
action='store_true',
default=False,
help="Enable performance profiling (macro and operator-level analysis)."
)
parser.add_argument(
"--profile_output_dir",
type=str,
default=None,
help="Directory to save profiling results. Defaults to {savedir}/profile_output."
)
parser.add_argument(
"--profile_iterations",
type=int,
default=3,
help="Number of iterations to run PyTorch profiler's active phase for operator-level analysis."
)
parser.add_argument(
"--profile_detail",
type=str,
choices=["light", "full"],
default="light",
help="Profiling detail level. Use 'full' for shapes/stacks/memory/flops."
)
return parser
@@ -1741,15 +1241,5 @@ if __name__ == '__main__':
seed = random.randint(0, 2**31)
seed_everything(seed)
# Initialize profiler
profile_output_dir = args.profile_output_dir
if profile_output_dir is None:
profile_output_dir = os.path.join(args.savedir, "profile_output")
init_profiler(
enabled=args.profile,
output_dir=profile_output_dir,
profile_detail=args.profile_detail,
)
rank, gpu_num = 0, 1
run_inference(args, gpu_num, rank)