remove profile
This commit is contained in:
@@ -9,12 +9,9 @@ import logging
|
||||
import einops
|
||||
import warnings
|
||||
import imageio
|
||||
import time
|
||||
import json
|
||||
import atexit
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional, Dict, List, Any, Mapping
|
||||
|
||||
from pytorch_lightning import seed_everything
|
||||
@@ -26,375 +23,12 @@ from torch import nn
|
||||
from eval_utils import populate_queues
|
||||
from collections import deque
|
||||
from torch import Tensor
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from PIL import Image
|
||||
|
||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
|
||||
|
||||
# ========== Profiling Infrastructure ==========
|
||||
@dataclass
|
||||
class TimingRecord:
|
||||
"""Record for a single timing measurement."""
|
||||
name: str
|
||||
start_time: float = 0.0
|
||||
end_time: float = 0.0
|
||||
cuda_time_ms: float = 0.0
|
||||
count: int = 0
|
||||
children: List['TimingRecord'] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def cpu_time_ms(self) -> float:
|
||||
return (self.end_time - self.start_time) * 1000
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
'name': self.name,
|
||||
'cpu_time_ms': self.cpu_time_ms,
|
||||
'cuda_time_ms': self.cuda_time_ms,
|
||||
'count': self.count,
|
||||
'children': [c.to_dict() for c in self.children]
|
||||
}
|
||||
|
||||
|
||||
class ProfilerManager:
|
||||
"""Manages macro and micro-level profiling."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool = False,
|
||||
output_dir: str = "./profile_output",
|
||||
profile_detail: str = "light",
|
||||
):
|
||||
self.enabled = enabled
|
||||
self.output_dir = output_dir
|
||||
self.profile_detail = profile_detail
|
||||
self.macro_timings: Dict[str, List[float]] = {}
|
||||
self.cuda_events: Dict[str, List[tuple]] = {}
|
||||
self.memory_snapshots: List[Dict] = []
|
||||
self.pytorch_profiler = None
|
||||
self.current_iteration = 0
|
||||
self.operator_stats: Dict[str, Dict] = {}
|
||||
self.profiler_config = self._build_profiler_config(profile_detail)
|
||||
|
||||
if enabled:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
def _build_profiler_config(self, profile_detail: str) -> Dict[str, Any]:
|
||||
"""Return profiler settings based on the requested detail level."""
|
||||
if profile_detail not in ("light", "full"):
|
||||
raise ValueError(f"Unsupported profile_detail: {profile_detail}")
|
||||
if profile_detail == "full":
|
||||
return {
|
||||
"record_shapes": True,
|
||||
"profile_memory": True,
|
||||
"with_stack": True,
|
||||
"with_flops": True,
|
||||
"with_modules": True,
|
||||
"group_by_input_shape": True,
|
||||
}
|
||||
return {
|
||||
"record_shapes": False,
|
||||
"profile_memory": False,
|
||||
"with_stack": False,
|
||||
"with_flops": False,
|
||||
"with_modules": False,
|
||||
"group_by_input_shape": False,
|
||||
}
|
||||
|
||||
@contextmanager
|
||||
def profile_section(self, name: str, sync_cuda: bool = True):
|
||||
"""Context manager for profiling a code section."""
|
||||
if not self.enabled:
|
||||
yield
|
||||
return
|
||||
|
||||
if sync_cuda and torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = None
|
||||
end_event = None
|
||||
if torch.cuda.is_available():
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event.record()
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if sync_cuda and torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
|
||||
end_time = time.perf_counter()
|
||||
cpu_time_ms = (end_time - start_time) * 1000
|
||||
|
||||
cuda_time_ms = 0.0
|
||||
if start_event is not None and end_event is not None:
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
cuda_time_ms = start_event.elapsed_time(end_event)
|
||||
|
||||
if name not in self.macro_timings:
|
||||
self.macro_timings[name] = []
|
||||
self.macro_timings[name].append(cpu_time_ms)
|
||||
|
||||
if name not in self.cuda_events:
|
||||
self.cuda_events[name] = []
|
||||
self.cuda_events[name].append((cpu_time_ms, cuda_time_ms))
|
||||
|
||||
def record_memory(self, tag: str = ""):
|
||||
"""Record current GPU memory state."""
|
||||
if not self.enabled or not torch.cuda.is_available():
|
||||
return
|
||||
|
||||
snapshot = {
|
||||
'tag': tag,
|
||||
'iteration': self.current_iteration,
|
||||
'allocated_mb': torch.cuda.memory_allocated() / 1024**2,
|
||||
'reserved_mb': torch.cuda.memory_reserved() / 1024**2,
|
||||
'max_allocated_mb': torch.cuda.max_memory_allocated() / 1024**2,
|
||||
}
|
||||
self.memory_snapshots.append(snapshot)
|
||||
|
||||
def start_pytorch_profiler(self, wait: int = 1, warmup: int = 1, active: int = 3):
|
||||
"""Start PyTorch profiler for operator-level analysis."""
|
||||
if not self.enabled:
|
||||
return nullcontext()
|
||||
|
||||
self.pytorch_profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
schedule=torch.profiler.schedule(
|
||||
wait=wait, warmup=warmup, active=active, repeat=1
|
||||
),
|
||||
on_trace_ready=self._trace_handler,
|
||||
record_shapes=self.profiler_config["record_shapes"],
|
||||
profile_memory=self.profiler_config["profile_memory"],
|
||||
with_stack=self.profiler_config["with_stack"],
|
||||
with_flops=self.profiler_config["with_flops"],
|
||||
with_modules=self.profiler_config["with_modules"],
|
||||
)
|
||||
return self.pytorch_profiler
|
||||
|
||||
def _trace_handler(self, prof):
|
||||
"""Handle profiler trace output."""
|
||||
trace_path = os.path.join(
|
||||
self.output_dir,
|
||||
f"trace_iter_{self.current_iteration}.json"
|
||||
)
|
||||
prof.export_chrome_trace(trace_path)
|
||||
|
||||
# Extract operator statistics
|
||||
key_averages = prof.key_averages(
|
||||
group_by_input_shape=self.profiler_config["group_by_input_shape"]
|
||||
)
|
||||
for evt in key_averages:
|
||||
op_name = evt.key
|
||||
if op_name not in self.operator_stats:
|
||||
self.operator_stats[op_name] = {
|
||||
'count': 0,
|
||||
'cpu_time_total_us': 0,
|
||||
'cuda_time_total_us': 0,
|
||||
'self_cpu_time_total_us': 0,
|
||||
'self_cuda_time_total_us': 0,
|
||||
'cpu_memory_usage': 0,
|
||||
'cuda_memory_usage': 0,
|
||||
'flops': 0,
|
||||
}
|
||||
stats = self.operator_stats[op_name]
|
||||
stats['count'] += evt.count
|
||||
stats['cpu_time_total_us'] += evt.cpu_time_total
|
||||
stats['cuda_time_total_us'] += evt.cuda_time_total
|
||||
stats['self_cpu_time_total_us'] += evt.self_cpu_time_total
|
||||
stats['self_cuda_time_total_us'] += evt.self_cuda_time_total
|
||||
if hasattr(evt, 'cpu_memory_usage'):
|
||||
stats['cpu_memory_usage'] += evt.cpu_memory_usage
|
||||
if hasattr(evt, 'cuda_memory_usage'):
|
||||
stats['cuda_memory_usage'] += evt.cuda_memory_usage
|
||||
if hasattr(evt, 'flops') and evt.flops:
|
||||
stats['flops'] += evt.flops
|
||||
|
||||
def step_profiler(self):
|
||||
"""Step the PyTorch profiler."""
|
||||
if self.pytorch_profiler is not None:
|
||||
self.pytorch_profiler.step()
|
||||
|
||||
def generate_report(self) -> str:
|
||||
"""Generate comprehensive profiling report."""
|
||||
if not self.enabled:
|
||||
return "Profiling disabled."
|
||||
|
||||
report_lines = []
|
||||
report_lines.append("=" * 80)
|
||||
report_lines.append("PERFORMANCE PROFILING REPORT")
|
||||
report_lines.append("=" * 80)
|
||||
report_lines.append("")
|
||||
|
||||
# Macro-level timing summary
|
||||
report_lines.append("-" * 40)
|
||||
report_lines.append("MACRO-LEVEL TIMING SUMMARY")
|
||||
report_lines.append("-" * 40)
|
||||
report_lines.append(f"{'Section':<40} {'Count':>8} {'Total(ms)':>12} {'Avg(ms)':>12} {'CUDA Avg(ms)':>14}")
|
||||
report_lines.append("-" * 86)
|
||||
|
||||
total_time = 0
|
||||
timing_data = []
|
||||
for name, times in sorted(self.macro_timings.items()):
|
||||
cuda_times = [ct for _, ct in self.cuda_events.get(name, [])]
|
||||
avg_time = np.mean(times)
|
||||
avg_cuda = np.mean(cuda_times) if cuda_times else 0
|
||||
total = sum(times)
|
||||
total_time += total
|
||||
timing_data.append({
|
||||
'name': name,
|
||||
'count': len(times),
|
||||
'total_ms': total,
|
||||
'avg_ms': avg_time,
|
||||
'cuda_avg_ms': avg_cuda,
|
||||
'times': times,
|
||||
'cuda_times': cuda_times,
|
||||
})
|
||||
report_lines.append(f"{name:<40} {len(times):>8} {total:>12.2f} {avg_time:>12.2f} {avg_cuda:>14.2f}")
|
||||
|
||||
report_lines.append("-" * 86)
|
||||
report_lines.append(f"{'TOTAL':<40} {'':<8} {total_time:>12.2f}")
|
||||
report_lines.append("")
|
||||
|
||||
# Memory summary
|
||||
if self.memory_snapshots:
|
||||
report_lines.append("-" * 40)
|
||||
report_lines.append("GPU MEMORY SUMMARY")
|
||||
report_lines.append("-" * 40)
|
||||
max_alloc = max(s['max_allocated_mb'] for s in self.memory_snapshots)
|
||||
avg_alloc = np.mean([s['allocated_mb'] for s in self.memory_snapshots])
|
||||
report_lines.append(f"Peak allocated: {max_alloc:>10.2f} MB")
|
||||
report_lines.append(f"Average allocated: {avg_alloc:>10.2f} MB")
|
||||
report_lines.append("")
|
||||
|
||||
# Top operators by CUDA time
|
||||
if self.operator_stats:
|
||||
report_lines.append("-" * 40)
|
||||
report_lines.append("TOP 30 OPERATORS BY CUDA TIME")
|
||||
report_lines.append("-" * 40)
|
||||
sorted_ops = sorted(
|
||||
self.operator_stats.items(),
|
||||
key=lambda x: x[1]['cuda_time_total_us'],
|
||||
reverse=True
|
||||
)[:30]
|
||||
|
||||
report_lines.append(f"{'Operator':<50} {'Count':>8} {'CUDA(ms)':>12} {'CPU(ms)':>12} {'Self CUDA(ms)':>14}")
|
||||
report_lines.append("-" * 96)
|
||||
|
||||
for op_name, stats in sorted_ops:
|
||||
# Truncate long operator names
|
||||
display_name = op_name[:47] + "..." if len(op_name) > 50 else op_name
|
||||
report_lines.append(
|
||||
f"{display_name:<50} {stats['count']:>8} "
|
||||
f"{stats['cuda_time_total_us']/1000:>12.2f} "
|
||||
f"{stats['cpu_time_total_us']/1000:>12.2f} "
|
||||
f"{stats['self_cuda_time_total_us']/1000:>14.2f}"
|
||||
)
|
||||
report_lines.append("")
|
||||
|
||||
# Compute category breakdown
|
||||
report_lines.append("-" * 40)
|
||||
report_lines.append("OPERATOR CATEGORY BREAKDOWN")
|
||||
report_lines.append("-" * 40)
|
||||
|
||||
categories = {
|
||||
'Attention': ['attention', 'softmax', 'bmm', 'baddbmm'],
|
||||
'Convolution': ['conv', 'cudnn'],
|
||||
'Normalization': ['norm', 'layer_norm', 'batch_norm', 'group_norm'],
|
||||
'Activation': ['relu', 'gelu', 'silu', 'sigmoid', 'tanh'],
|
||||
'Linear/GEMM': ['linear', 'addmm', 'mm', 'gemm'],
|
||||
'Memory': ['copy', 'contiguous', 'view', 'reshape', 'permute', 'transpose'],
|
||||
'Elementwise': ['add', 'mul', 'div', 'sub', 'pow', 'exp', 'sqrt'],
|
||||
}
|
||||
|
||||
category_times = {cat: 0.0 for cat in categories}
|
||||
category_times['Other'] = 0.0
|
||||
|
||||
for op_name, stats in self.operator_stats.items():
|
||||
op_lower = op_name.lower()
|
||||
categorized = False
|
||||
for cat, keywords in categories.items():
|
||||
if any(kw in op_lower for kw in keywords):
|
||||
category_times[cat] += stats['cuda_time_total_us']
|
||||
categorized = True
|
||||
break
|
||||
if not categorized:
|
||||
category_times['Other'] += stats['cuda_time_total_us']
|
||||
|
||||
total_op_time = sum(category_times.values())
|
||||
report_lines.append(f"{'Category':<30} {'CUDA Time(ms)':>15} {'Percentage':>12}")
|
||||
report_lines.append("-" * 57)
|
||||
for cat, time_us in sorted(category_times.items(), key=lambda x: -x[1]):
|
||||
pct = (time_us / total_op_time * 100) if total_op_time > 0 else 0
|
||||
report_lines.append(f"{cat:<30} {time_us/1000:>15.2f} {pct:>11.1f}%")
|
||||
report_lines.append("")
|
||||
|
||||
report = "\n".join(report_lines)
|
||||
return report
|
||||
|
||||
def save_results(self):
|
||||
"""Save all profiling results to files."""
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
# Save report
|
||||
report = self.generate_report()
|
||||
report_path = os.path.join(self.output_dir, "profiling_report.txt")
|
||||
with open(report_path, 'w') as f:
|
||||
f.write(report)
|
||||
print(f">>> Profiling report saved to: {report_path}")
|
||||
|
||||
# Save detailed JSON data
|
||||
data = {
|
||||
'macro_timings': {
|
||||
name: {
|
||||
'times': times,
|
||||
'cuda_times': [ct for _, ct in self.cuda_events.get(name, [])]
|
||||
}
|
||||
for name, times in self.macro_timings.items()
|
||||
},
|
||||
'memory_snapshots': self.memory_snapshots,
|
||||
'operator_stats': self.operator_stats,
|
||||
}
|
||||
json_path = os.path.join(self.output_dir, "profiling_data.json")
|
||||
with open(json_path, 'w') as f:
|
||||
json.dump(data, f, indent=2)
|
||||
print(f">>> Detailed profiling data saved to: {json_path}")
|
||||
|
||||
# Print summary to console
|
||||
print("\n" + report)
|
||||
|
||||
|
||||
# Global profiler instance
|
||||
_profiler: Optional[ProfilerManager] = None
|
||||
|
||||
def get_profiler() -> ProfilerManager:
|
||||
"""Get the global profiler instance."""
|
||||
global _profiler
|
||||
if _profiler is None:
|
||||
_profiler = ProfilerManager(enabled=False)
|
||||
return _profiler
|
||||
|
||||
def init_profiler(enabled: bool, output_dir: str, profile_detail: str) -> ProfilerManager:
|
||||
"""Initialize the global profiler."""
|
||||
global _profiler
|
||||
_profiler = ProfilerManager(
|
||||
enabled=enabled,
|
||||
output_dir=output_dir,
|
||||
profile_detail=profile_detail,
|
||||
)
|
||||
return _profiler
|
||||
|
||||
|
||||
# ========== Async I/O ==========
|
||||
_io_executor: Optional[ThreadPoolExecutor] = None
|
||||
_io_futures: List[Any] = []
|
||||
@@ -447,28 +81,6 @@ def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
|
||||
_io_futures.append(fut)
|
||||
|
||||
|
||||
def _log_to_tb_sync(writer, video_cpu: Tensor, tag: str, fps: int) -> None:
|
||||
"""Synchronous TensorBoard log on CPU tensor (runs in background thread)."""
|
||||
if video_cpu.dim() == 5:
|
||||
n = video_cpu.shape[0]
|
||||
video = video_cpu.permute(2, 0, 1, 3, 4)
|
||||
frame_grids = [
|
||||
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
|
||||
for framesheet in video
|
||||
]
|
||||
grid = torch.stack(frame_grids, dim=0)
|
||||
grid = (grid + 1.0) / 2.0
|
||||
grid = grid.unsqueeze(dim=0)
|
||||
writer.add_video(tag, grid, fps=fps)
|
||||
|
||||
|
||||
def log_to_tensorboard_async(writer, data: Tensor, tag: str, fps: int = 10) -> None:
|
||||
"""Submit TensorBoard logging to background thread pool."""
|
||||
if isinstance(data, torch.Tensor) and data.dim() == 5:
|
||||
data_cpu = data.detach().cpu()
|
||||
fut = _get_io_executor().submit(_log_to_tb_sync, writer, data_cpu, tag, fps)
|
||||
_io_futures.append(fut)
|
||||
|
||||
|
||||
# ========== Original Functions ==========
|
||||
def get_device_from_parameters(module: nn.Module) -> torch.device:
|
||||
@@ -590,9 +202,7 @@ def load_model_checkpoint(model: nn.Module,
|
||||
|
||||
def maybe_cast_module(module: nn.Module | None,
|
||||
dtype: torch.dtype,
|
||||
label: str,
|
||||
profiler: Optional[ProfilerManager] = None,
|
||||
profile_name: Optional[str] = None) -> None:
|
||||
label: str) -> None:
|
||||
if module is None:
|
||||
return
|
||||
try:
|
||||
@@ -603,10 +213,6 @@ def maybe_cast_module(module: nn.Module | None,
|
||||
if param.dtype == dtype:
|
||||
print(f">>> {label} already {dtype}; skip cast")
|
||||
return
|
||||
ctx = nullcontext()
|
||||
if profiler is not None and profile_name:
|
||||
ctx = profiler.profile_section(profile_name)
|
||||
with ctx:
|
||||
module.to(dtype=dtype)
|
||||
print(f">>> {label} cast to {dtype}")
|
||||
|
||||
@@ -825,8 +431,6 @@ def get_latent_z(model, videos: Tensor) -> Tensor:
|
||||
Returns:
|
||||
Tensor: Latent video tensor of shape [B, C, T, H, W].
|
||||
"""
|
||||
profiler = get_profiler()
|
||||
with profiler.profile_section("get_latent_z/encode"):
|
||||
b, c, t, h, w = videos.shape
|
||||
x = rearrange(videos, 'b c t h w -> (b t) c h w')
|
||||
vae_ctx = nullcontext()
|
||||
@@ -941,8 +545,6 @@ def image_guided_synthesis_sim_mode(
|
||||
actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
|
||||
states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
|
||||
"""
|
||||
profiler = get_profiler()
|
||||
|
||||
b, _, t, _, _ = noise_shape
|
||||
ddim_sampler = getattr(model, "_ddim_sampler", None)
|
||||
if ddim_sampler is None:
|
||||
@@ -952,7 +554,6 @@ def image_guided_synthesis_sim_mode(
|
||||
|
||||
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
|
||||
|
||||
with profiler.profile_section("synthesis/conditioning_prep"):
|
||||
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
||||
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
|
||||
if getattr(model, "encoder_bf16", False) and model.device.type == "cuda":
|
||||
@@ -1038,7 +639,6 @@ def image_guided_synthesis_sim_mode(
|
||||
cond_z0 = None
|
||||
|
||||
if ddim_sampler is not None:
|
||||
with profiler.profile_section("synthesis/ddim_sampling"):
|
||||
autocast_ctx = nullcontext()
|
||||
if diffusion_autocast_dtype is not None and model.device.type == "cuda":
|
||||
autocast_ctx = torch.autocast("cuda", dtype=diffusion_autocast_dtype)
|
||||
@@ -1061,7 +661,6 @@ def image_guided_synthesis_sim_mode(
|
||||
**kwargs)
|
||||
|
||||
# Reconstruct from latent to pixel space
|
||||
with profiler.profile_section("synthesis/decode_first_stage"):
|
||||
if getattr(model, "vae_bf16", False):
|
||||
if samples.dtype != torch.bfloat16:
|
||||
samples = samples.to(dtype=torch.bfloat16)
|
||||
@@ -1091,13 +690,8 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
profiler = get_profiler()
|
||||
|
||||
# Create inference and tensorboard dirs
|
||||
# Create inference dir
|
||||
os.makedirs(args.savedir + '/inference', exist_ok=True)
|
||||
log_dir = args.savedir + f"/tensorboard"
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
writer = SummaryWriter(log_dir=log_dir)
|
||||
|
||||
# Load prompt
|
||||
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
|
||||
@@ -1110,7 +704,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
if os.path.exists(prepared_path):
|
||||
# ---- Fast path: load the fully-prepared model ----
|
||||
print(f">>> Loading prepared model from {prepared_path} ...")
|
||||
with profiler.profile_section("model_loading/prepared"):
|
||||
model = torch.load(prepared_path,
|
||||
map_location=f"cuda:{gpu_no}",
|
||||
weights_only=False,
|
||||
@@ -1122,7 +715,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
print(f">>> Prepared model loaded.")
|
||||
else:
|
||||
# ---- Normal path: construct + checkpoint + casting ----
|
||||
with profiler.profile_section("model_loading/config"):
|
||||
config['model']['params']['wma_config']['params'][
|
||||
'use_checkpoint'] = False
|
||||
model = instantiate_from_config(config.model)
|
||||
@@ -1130,7 +722,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
|
||||
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
|
||||
|
||||
with profiler.profile_section("model_loading/checkpoint"):
|
||||
model = load_model_checkpoint(model, args.ckpt_path,
|
||||
device=f"cuda:{gpu_no}")
|
||||
model.eval()
|
||||
@@ -1143,8 +734,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
model.model,
|
||||
torch.bfloat16,
|
||||
"diffusion backbone",
|
||||
profiler=profiler,
|
||||
profile_name="model_loading/diffusion_bf16",
|
||||
)
|
||||
diffusion_autocast_dtype = torch.bfloat16
|
||||
print(">>> diffusion backbone set to bfloat16")
|
||||
@@ -1155,8 +744,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
model.first_stage_model,
|
||||
vae_weight_dtype,
|
||||
"VAE",
|
||||
profiler=profiler,
|
||||
profile_name="model_loading/vae_cast",
|
||||
)
|
||||
model.vae_bf16 = args.vae_dtype == "bf16"
|
||||
print(f">>> VAE dtype set to {args.vae_dtype}")
|
||||
@@ -1195,16 +782,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
model.cond_stage_model,
|
||||
encoder_weight_dtype,
|
||||
"cond_stage_model",
|
||||
profiler=profiler,
|
||||
profile_name="model_loading/encoder_cond_cast",
|
||||
)
|
||||
if hasattr(model, "embedder") and model.embedder is not None:
|
||||
maybe_cast_module(
|
||||
model.embedder,
|
||||
encoder_weight_dtype,
|
||||
"embedder",
|
||||
profiler=profiler,
|
||||
profile_name="model_loading/encoder_embedder_cast",
|
||||
)
|
||||
model.encoder_bf16 = encoder_bf16
|
||||
model.encoder_mode = encoder_mode
|
||||
@@ -1220,24 +803,18 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
model.image_proj_model,
|
||||
projector_weight_dtype,
|
||||
"image_proj_model",
|
||||
profiler=profiler,
|
||||
profile_name="model_loading/projector_image_cast",
|
||||
)
|
||||
if hasattr(model, "state_projector") and model.state_projector is not None:
|
||||
maybe_cast_module(
|
||||
model.state_projector,
|
||||
projector_weight_dtype,
|
||||
"state_projector",
|
||||
profiler=profiler,
|
||||
profile_name="model_loading/projector_state_cast",
|
||||
)
|
||||
if hasattr(model, "action_projector") and model.action_projector is not None:
|
||||
maybe_cast_module(
|
||||
model.action_projector,
|
||||
projector_weight_dtype,
|
||||
"action_projector",
|
||||
profiler=profiler,
|
||||
profile_name="model_loading/projector_action_cast",
|
||||
)
|
||||
if hasattr(model, "projector_bf16"):
|
||||
model.projector_bf16 = projector_bf16
|
||||
@@ -1269,14 +846,11 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
|
||||
# Build normalizer (always needed, independent of model loading path)
|
||||
logging.info("***** Configing Data *****")
|
||||
with profiler.profile_section("data_loading"):
|
||||
data = instantiate_from_config(config.data)
|
||||
data.setup()
|
||||
print(">>> Dataset is successfully loaded ...")
|
||||
device = get_device_from_parameters(model)
|
||||
|
||||
profiler.record_memory("after_model_load")
|
||||
|
||||
# Run over data
|
||||
assert (args.height % 16 == 0) and (
|
||||
args.width % 16
|
||||
@@ -1290,10 +864,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
print(f'>>> Generate {n_frames} frames under each generation ...')
|
||||
noise_shape = [args.bs, channels, n_frames, h, w]
|
||||
|
||||
# Determine profiler iterations
|
||||
profile_active_iters = getattr(args, 'profile_iterations', 3)
|
||||
use_pytorch_profiler = profiler.enabled and profile_active_iters > 0
|
||||
|
||||
# Start inference
|
||||
for idx in range(0, len(df)):
|
||||
sample = df.iloc[idx]
|
||||
@@ -1309,7 +879,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
|
||||
# Load transitions to get the initial state later
|
||||
transition_path = get_transition_path(args.prompt_dir, sample)
|
||||
with profiler.profile_section("load_transitions"):
|
||||
with h5py.File(transition_path, 'r') as h5f:
|
||||
transition_dict = {}
|
||||
for key in h5f.keys():
|
||||
@@ -1337,7 +906,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
}
|
||||
|
||||
# Obtain initial frame and state
|
||||
with profiler.profile_section("prepare_init_input"):
|
||||
start_idx = 0
|
||||
model_input_fs = ori_fps // fs
|
||||
batch, ori_state_dim, ori_action_dim = prepare_init_input(
|
||||
@@ -1360,24 +928,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
# Update observation queues
|
||||
cond_obs_queues = populate_queues(cond_obs_queues, observation)
|
||||
|
||||
# Setup PyTorch profiler context if enabled
|
||||
pytorch_prof_ctx = nullcontext()
|
||||
if use_pytorch_profiler:
|
||||
pytorch_prof_ctx = profiler.start_pytorch_profiler(
|
||||
wait=1, warmup=1, active=profile_active_iters
|
||||
)
|
||||
|
||||
# Multi-round interaction with the world-model
|
||||
with pytorch_prof_ctx:
|
||||
for itr in tqdm(range(args.n_iter)):
|
||||
log_every = max(1, args.step_log_every)
|
||||
log_step = (itr % log_every == 0)
|
||||
profiler.current_iteration = itr
|
||||
profiler.record_memory(f"iter_{itr}_start")
|
||||
|
||||
with profiler.profile_section("iteration_total"):
|
||||
# Get observation
|
||||
with profiler.profile_section("prepare_observation"):
|
||||
observation = {
|
||||
'observation.images.top':
|
||||
torch.stack(list(
|
||||
@@ -1394,7 +950,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
# Use world-model in policy to generate action
|
||||
if log_step:
|
||||
print(f'>>> Step {itr}: generating actions ...')
|
||||
with profiler.profile_section("action_generation"):
|
||||
pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
|
||||
model,
|
||||
sample['instruction'],
|
||||
@@ -1412,7 +967,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
diffusion_autocast_dtype=diffusion_autocast_dtype)
|
||||
|
||||
# Update future actions in the observation queues
|
||||
with profiler.profile_section("update_action_queues"):
|
||||
for act_idx in range(len(pred_actions[0])):
|
||||
obs_update = {'action': pred_actions[0][act_idx:act_idx + 1]}
|
||||
obs_update['action'][:, ori_action_dim:] = 0.0
|
||||
@@ -1420,7 +974,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
obs_update)
|
||||
|
||||
# Collect data for interacting the world-model using the predicted actions
|
||||
with profiler.profile_section("prepare_wm_observation"):
|
||||
observation = {
|
||||
'observation.images.top':
|
||||
torch.stack(list(
|
||||
@@ -1437,7 +990,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
# Interaction with the world-model
|
||||
if log_step:
|
||||
print(f'>>> Step {itr}: interacting with world model ...')
|
||||
with profiler.profile_section("world_model_interaction"):
|
||||
pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
|
||||
model,
|
||||
"",
|
||||
@@ -1454,7 +1006,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
diffusion_autocast_dtype=diffusion_autocast_dtype)
|
||||
|
||||
with profiler.profile_section("update_state_queues"):
|
||||
for step_idx in range(args.exe_steps):
|
||||
obs_update = {
|
||||
'observation.images.top':
|
||||
@@ -1470,20 +1021,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
obs_update)
|
||||
|
||||
# Save the imagen videos for decision-making (async)
|
||||
with profiler.profile_section("save_results"):
|
||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
|
||||
log_to_tensorboard_async(writer,
|
||||
pred_videos_0,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
# Save videos environment changes via world-model interaction
|
||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
|
||||
log_to_tensorboard_async(writer,
|
||||
pred_videos_1,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
|
||||
# Save the imagen videos for decision-making
|
||||
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
|
||||
save_results_async(pred_videos_0,
|
||||
sample_video_file,
|
||||
@@ -1498,24 +1035,13 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
# Collect the result of world-model interactions
|
||||
wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu())
|
||||
|
||||
profiler.record_memory(f"iter_{itr}_end")
|
||||
profiler.step_profiler()
|
||||
|
||||
full_video = torch.cat(wm_video, dim=2)
|
||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
|
||||
log_to_tensorboard_async(writer,
|
||||
full_video,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
|
||||
save_results_async(full_video, sample_full_video_file, fps=args.save_fps)
|
||||
|
||||
# Wait for all async I/O to complete before profiling report
|
||||
# Wait for all async I/O to complete
|
||||
_flush_io()
|
||||
|
||||
# Save profiling results
|
||||
profiler.save_results()
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -1704,32 +1230,6 @@ def get_parser():
|
||||
type=int,
|
||||
default=8,
|
||||
help="fps for the saving video")
|
||||
# Profiling arguments
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Enable performance profiling (macro and operator-level analysis)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile_output_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory to save profiling results. Defaults to {savedir}/profile_output."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile_iterations",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Number of iterations to run PyTorch profiler's active phase for operator-level analysis."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile_detail",
|
||||
type=str,
|
||||
choices=["light", "full"],
|
||||
default="light",
|
||||
help="Profiling detail level. Use 'full' for shapes/stacks/memory/flops."
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
@@ -1741,15 +1241,5 @@ if __name__ == '__main__':
|
||||
seed = random.randint(0, 2**31)
|
||||
seed_everything(seed)
|
||||
|
||||
# Initialize profiler
|
||||
profile_output_dir = args.profile_output_dir
|
||||
if profile_output_dir is None:
|
||||
profile_output_dir = os.path.join(args.savedir, "profile_output")
|
||||
init_profiler(
|
||||
enabled=args.profile,
|
||||
output_dir=profile_output_dir,
|
||||
profile_detail=args.profile_detail,
|
||||
)
|
||||
|
||||
rank, gpu_num = 0, 1
|
||||
run_inference(args, gpu_num, rank)
|
||||
|
||||
Reference in New Issue
Block a user