From f1f92072e66212856504e757b0d20b4558283d3b Mon Sep 17 00:00:00 2001 From: qhy <2728290997@qq.com> Date: Tue, 10 Feb 2026 11:28:26 +0800 Subject: [PATCH] remove profile --- scripts/evaluation/world_model_interaction.py | 998 +++++------------- 1 file changed, 244 insertions(+), 754 deletions(-) diff --git a/scripts/evaluation/world_model_interaction.py b/scripts/evaluation/world_model_interaction.py index 00bd36b..281693c 100644 --- a/scripts/evaluation/world_model_interaction.py +++ b/scripts/evaluation/world_model_interaction.py @@ -9,12 +9,9 @@ import logging import einops import warnings import imageio -import time -import json import atexit from concurrent.futures import ThreadPoolExecutor -from contextlib import contextmanager, nullcontext -from dataclasses import dataclass, field, asdict +from contextlib import nullcontext from typing import Optional, Dict, List, Any, Mapping from pytorch_lightning import seed_everything @@ -26,375 +23,12 @@ from torch import nn from eval_utils import populate_queues from collections import deque from torch import Tensor -from torch.utils.tensorboard import SummaryWriter from PIL import Image from unifolm_wma.models.samplers.ddim import DDIMSampler from unifolm_wma.utils.utils import instantiate_from_config -# ========== Profiling Infrastructure ========== -@dataclass -class TimingRecord: - """Record for a single timing measurement.""" - name: str - start_time: float = 0.0 - end_time: float = 0.0 - cuda_time_ms: float = 0.0 - count: int = 0 - children: List['TimingRecord'] = field(default_factory=list) - - @property - def cpu_time_ms(self) -> float: - return (self.end_time - self.start_time) * 1000 - - def to_dict(self) -> dict: - return { - 'name': self.name, - 'cpu_time_ms': self.cpu_time_ms, - 'cuda_time_ms': self.cuda_time_ms, - 'count': self.count, - 'children': [c.to_dict() for c in self.children] - } - - -class ProfilerManager: - """Manages macro and micro-level profiling.""" - - def __init__( - self, - enabled: bool = False, - output_dir: str = "./profile_output", - profile_detail: str = "light", - ): - self.enabled = enabled - self.output_dir = output_dir - self.profile_detail = profile_detail - self.macro_timings: Dict[str, List[float]] = {} - self.cuda_events: Dict[str, List[tuple]] = {} - self.memory_snapshots: List[Dict] = [] - self.pytorch_profiler = None - self.current_iteration = 0 - self.operator_stats: Dict[str, Dict] = {} - self.profiler_config = self._build_profiler_config(profile_detail) - - if enabled: - os.makedirs(output_dir, exist_ok=True) - - def _build_profiler_config(self, profile_detail: str) -> Dict[str, Any]: - """Return profiler settings based on the requested detail level.""" - if profile_detail not in ("light", "full"): - raise ValueError(f"Unsupported profile_detail: {profile_detail}") - if profile_detail == "full": - return { - "record_shapes": True, - "profile_memory": True, - "with_stack": True, - "with_flops": True, - "with_modules": True, - "group_by_input_shape": True, - } - return { - "record_shapes": False, - "profile_memory": False, - "with_stack": False, - "with_flops": False, - "with_modules": False, - "group_by_input_shape": False, - } - - @contextmanager - def profile_section(self, name: str, sync_cuda: bool = True): - """Context manager for profiling a code section.""" - if not self.enabled: - yield - return - - if sync_cuda and torch.cuda.is_available(): - torch.cuda.synchronize() - - start_event = None - end_event = None - if torch.cuda.is_available(): - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - - start_time = time.perf_counter() - - try: - yield - finally: - if sync_cuda and torch.cuda.is_available(): - torch.cuda.synchronize() - - end_time = time.perf_counter() - cpu_time_ms = (end_time - start_time) * 1000 - - cuda_time_ms = 0.0 - if start_event is not None and end_event is not None: - end_event.record() - torch.cuda.synchronize() - cuda_time_ms = start_event.elapsed_time(end_event) - - if name not in self.macro_timings: - self.macro_timings[name] = [] - self.macro_timings[name].append(cpu_time_ms) - - if name not in self.cuda_events: - self.cuda_events[name] = [] - self.cuda_events[name].append((cpu_time_ms, cuda_time_ms)) - - def record_memory(self, tag: str = ""): - """Record current GPU memory state.""" - if not self.enabled or not torch.cuda.is_available(): - return - - snapshot = { - 'tag': tag, - 'iteration': self.current_iteration, - 'allocated_mb': torch.cuda.memory_allocated() / 1024**2, - 'reserved_mb': torch.cuda.memory_reserved() / 1024**2, - 'max_allocated_mb': torch.cuda.max_memory_allocated() / 1024**2, - } - self.memory_snapshots.append(snapshot) - - def start_pytorch_profiler(self, wait: int = 1, warmup: int = 1, active: int = 3): - """Start PyTorch profiler for operator-level analysis.""" - if not self.enabled: - return nullcontext() - - self.pytorch_profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - schedule=torch.profiler.schedule( - wait=wait, warmup=warmup, active=active, repeat=1 - ), - on_trace_ready=self._trace_handler, - record_shapes=self.profiler_config["record_shapes"], - profile_memory=self.profiler_config["profile_memory"], - with_stack=self.profiler_config["with_stack"], - with_flops=self.profiler_config["with_flops"], - with_modules=self.profiler_config["with_modules"], - ) - return self.pytorch_profiler - - def _trace_handler(self, prof): - """Handle profiler trace output.""" - trace_path = os.path.join( - self.output_dir, - f"trace_iter_{self.current_iteration}.json" - ) - prof.export_chrome_trace(trace_path) - - # Extract operator statistics - key_averages = prof.key_averages( - group_by_input_shape=self.profiler_config["group_by_input_shape"] - ) - for evt in key_averages: - op_name = evt.key - if op_name not in self.operator_stats: - self.operator_stats[op_name] = { - 'count': 0, - 'cpu_time_total_us': 0, - 'cuda_time_total_us': 0, - 'self_cpu_time_total_us': 0, - 'self_cuda_time_total_us': 0, - 'cpu_memory_usage': 0, - 'cuda_memory_usage': 0, - 'flops': 0, - } - stats = self.operator_stats[op_name] - stats['count'] += evt.count - stats['cpu_time_total_us'] += evt.cpu_time_total - stats['cuda_time_total_us'] += evt.cuda_time_total - stats['self_cpu_time_total_us'] += evt.self_cpu_time_total - stats['self_cuda_time_total_us'] += evt.self_cuda_time_total - if hasattr(evt, 'cpu_memory_usage'): - stats['cpu_memory_usage'] += evt.cpu_memory_usage - if hasattr(evt, 'cuda_memory_usage'): - stats['cuda_memory_usage'] += evt.cuda_memory_usage - if hasattr(evt, 'flops') and evt.flops: - stats['flops'] += evt.flops - - def step_profiler(self): - """Step the PyTorch profiler.""" - if self.pytorch_profiler is not None: - self.pytorch_profiler.step() - - def generate_report(self) -> str: - """Generate comprehensive profiling report.""" - if not self.enabled: - return "Profiling disabled." - - report_lines = [] - report_lines.append("=" * 80) - report_lines.append("PERFORMANCE PROFILING REPORT") - report_lines.append("=" * 80) - report_lines.append("") - - # Macro-level timing summary - report_lines.append("-" * 40) - report_lines.append("MACRO-LEVEL TIMING SUMMARY") - report_lines.append("-" * 40) - report_lines.append(f"{'Section':<40} {'Count':>8} {'Total(ms)':>12} {'Avg(ms)':>12} {'CUDA Avg(ms)':>14}") - report_lines.append("-" * 86) - - total_time = 0 - timing_data = [] - for name, times in sorted(self.macro_timings.items()): - cuda_times = [ct for _, ct in self.cuda_events.get(name, [])] - avg_time = np.mean(times) - avg_cuda = np.mean(cuda_times) if cuda_times else 0 - total = sum(times) - total_time += total - timing_data.append({ - 'name': name, - 'count': len(times), - 'total_ms': total, - 'avg_ms': avg_time, - 'cuda_avg_ms': avg_cuda, - 'times': times, - 'cuda_times': cuda_times, - }) - report_lines.append(f"{name:<40} {len(times):>8} {total:>12.2f} {avg_time:>12.2f} {avg_cuda:>14.2f}") - - report_lines.append("-" * 86) - report_lines.append(f"{'TOTAL':<40} {'':<8} {total_time:>12.2f}") - report_lines.append("") - - # Memory summary - if self.memory_snapshots: - report_lines.append("-" * 40) - report_lines.append("GPU MEMORY SUMMARY") - report_lines.append("-" * 40) - max_alloc = max(s['max_allocated_mb'] for s in self.memory_snapshots) - avg_alloc = np.mean([s['allocated_mb'] for s in self.memory_snapshots]) - report_lines.append(f"Peak allocated: {max_alloc:>10.2f} MB") - report_lines.append(f"Average allocated: {avg_alloc:>10.2f} MB") - report_lines.append("") - - # Top operators by CUDA time - if self.operator_stats: - report_lines.append("-" * 40) - report_lines.append("TOP 30 OPERATORS BY CUDA TIME") - report_lines.append("-" * 40) - sorted_ops = sorted( - self.operator_stats.items(), - key=lambda x: x[1]['cuda_time_total_us'], - reverse=True - )[:30] - - report_lines.append(f"{'Operator':<50} {'Count':>8} {'CUDA(ms)':>12} {'CPU(ms)':>12} {'Self CUDA(ms)':>14}") - report_lines.append("-" * 96) - - for op_name, stats in sorted_ops: - # Truncate long operator names - display_name = op_name[:47] + "..." if len(op_name) > 50 else op_name - report_lines.append( - f"{display_name:<50} {stats['count']:>8} " - f"{stats['cuda_time_total_us']/1000:>12.2f} " - f"{stats['cpu_time_total_us']/1000:>12.2f} " - f"{stats['self_cuda_time_total_us']/1000:>14.2f}" - ) - report_lines.append("") - - # Compute category breakdown - report_lines.append("-" * 40) - report_lines.append("OPERATOR CATEGORY BREAKDOWN") - report_lines.append("-" * 40) - - categories = { - 'Attention': ['attention', 'softmax', 'bmm', 'baddbmm'], - 'Convolution': ['conv', 'cudnn'], - 'Normalization': ['norm', 'layer_norm', 'batch_norm', 'group_norm'], - 'Activation': ['relu', 'gelu', 'silu', 'sigmoid', 'tanh'], - 'Linear/GEMM': ['linear', 'addmm', 'mm', 'gemm'], - 'Memory': ['copy', 'contiguous', 'view', 'reshape', 'permute', 'transpose'], - 'Elementwise': ['add', 'mul', 'div', 'sub', 'pow', 'exp', 'sqrt'], - } - - category_times = {cat: 0.0 for cat in categories} - category_times['Other'] = 0.0 - - for op_name, stats in self.operator_stats.items(): - op_lower = op_name.lower() - categorized = False - for cat, keywords in categories.items(): - if any(kw in op_lower for kw in keywords): - category_times[cat] += stats['cuda_time_total_us'] - categorized = True - break - if not categorized: - category_times['Other'] += stats['cuda_time_total_us'] - - total_op_time = sum(category_times.values()) - report_lines.append(f"{'Category':<30} {'CUDA Time(ms)':>15} {'Percentage':>12}") - report_lines.append("-" * 57) - for cat, time_us in sorted(category_times.items(), key=lambda x: -x[1]): - pct = (time_us / total_op_time * 100) if total_op_time > 0 else 0 - report_lines.append(f"{cat:<30} {time_us/1000:>15.2f} {pct:>11.1f}%") - report_lines.append("") - - report = "\n".join(report_lines) - return report - - def save_results(self): - """Save all profiling results to files.""" - if not self.enabled: - return - - # Save report - report = self.generate_report() - report_path = os.path.join(self.output_dir, "profiling_report.txt") - with open(report_path, 'w') as f: - f.write(report) - print(f">>> Profiling report saved to: {report_path}") - - # Save detailed JSON data - data = { - 'macro_timings': { - name: { - 'times': times, - 'cuda_times': [ct for _, ct in self.cuda_events.get(name, [])] - } - for name, times in self.macro_timings.items() - }, - 'memory_snapshots': self.memory_snapshots, - 'operator_stats': self.operator_stats, - } - json_path = os.path.join(self.output_dir, "profiling_data.json") - with open(json_path, 'w') as f: - json.dump(data, f, indent=2) - print(f">>> Detailed profiling data saved to: {json_path}") - - # Print summary to console - print("\n" + report) - - -# Global profiler instance -_profiler: Optional[ProfilerManager] = None - -def get_profiler() -> ProfilerManager: - """Get the global profiler instance.""" - global _profiler - if _profiler is None: - _profiler = ProfilerManager(enabled=False) - return _profiler - -def init_profiler(enabled: bool, output_dir: str, profile_detail: str) -> ProfilerManager: - """Initialize the global profiler.""" - global _profiler - _profiler = ProfilerManager( - enabled=enabled, - output_dir=output_dir, - profile_detail=profile_detail, - ) - return _profiler - - # ========== Async I/O ========== _io_executor: Optional[ThreadPoolExecutor] = None _io_futures: List[Any] = [] @@ -447,28 +81,6 @@ def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None: _io_futures.append(fut) -def _log_to_tb_sync(writer, video_cpu: Tensor, tag: str, fps: int) -> None: - """Synchronous TensorBoard log on CPU tensor (runs in background thread).""" - if video_cpu.dim() == 5: - n = video_cpu.shape[0] - video = video_cpu.permute(2, 0, 1, 3, 4) - frame_grids = [ - torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) - for framesheet in video - ] - grid = torch.stack(frame_grids, dim=0) - grid = (grid + 1.0) / 2.0 - grid = grid.unsqueeze(dim=0) - writer.add_video(tag, grid, fps=fps) - - -def log_to_tensorboard_async(writer, data: Tensor, tag: str, fps: int = 10) -> None: - """Submit TensorBoard logging to background thread pool.""" - if isinstance(data, torch.Tensor) and data.dim() == 5: - data_cpu = data.detach().cpu() - fut = _get_io_executor().submit(_log_to_tb_sync, writer, data_cpu, tag, fps) - _io_futures.append(fut) - # ========== Original Functions ========== def get_device_from_parameters(module: nn.Module) -> torch.device: @@ -590,9 +202,7 @@ def load_model_checkpoint(model: nn.Module, def maybe_cast_module(module: nn.Module | None, dtype: torch.dtype, - label: str, - profiler: Optional[ProfilerManager] = None, - profile_name: Optional[str] = None) -> None: + label: str) -> None: if module is None: return try: @@ -603,11 +213,7 @@ def maybe_cast_module(module: nn.Module | None, if param.dtype == dtype: print(f">>> {label} already {dtype}; skip cast") return - ctx = nullcontext() - if profiler is not None and profile_name: - ctx = profiler.profile_section(profile_name) - with ctx: - module.to(dtype=dtype) + module.to(dtype=dtype) print(f">>> {label} cast to {dtype}") @@ -825,16 +431,14 @@ def get_latent_z(model, videos: Tensor) -> Tensor: Returns: Tensor: Latent video tensor of shape [B, C, T, H, W]. """ - profiler = get_profiler() - with profiler.profile_section("get_latent_z/encode"): - b, c, t, h, w = videos.shape - x = rearrange(videos, 'b c t h w -> (b t) c h w') - vae_ctx = nullcontext() - if getattr(model, "vae_bf16", False) and model.device.type == "cuda": - vae_ctx = torch.autocast("cuda", dtype=torch.bfloat16) - with vae_ctx: - z = model.encode_first_stage(x) - z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t) + b, c, t, h, w = videos.shape + x = rearrange(videos, 'b c t h w -> (b t) c h w') + vae_ctx = nullcontext() + if getattr(model, "vae_bf16", False) and model.device.type == "cuda": + vae_ctx = torch.autocast("cuda", dtype=torch.bfloat16) + with vae_ctx: + z = model.encode_first_stage(x) + z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t) return z @@ -941,8 +545,6 @@ def image_guided_synthesis_sim_mode( actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding. states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding. """ - profiler = get_profiler() - b, _, t, _, _ = noise_shape ddim_sampler = getattr(model, "_ddim_sampler", None) if ddim_sampler is None: @@ -952,85 +554,84 @@ def image_guided_synthesis_sim_mode( fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device) - with profiler.profile_section("synthesis/conditioning_prep"): - img = observation['observation.images.top'].permute(0, 2, 1, 3, 4) - cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:] - if getattr(model, "encoder_bf16", False) and model.device.type == "cuda": - if getattr(model, "encoder_mode", "autocast") == "autocast": - preprocess_ctx = torch.autocast("cuda", enabled=False) - with preprocess_ctx: - cond_img_fp32 = cond_img.float() - if hasattr(model.embedder, "preprocess"): - preprocessed = model.embedder.preprocess(cond_img_fp32) - else: - preprocessed = cond_img_fp32 - - if hasattr(model.embedder, - "encode_with_vision_transformer") and hasattr( - model.embedder, "preprocess"): - original_preprocess = model.embedder.preprocess - try: - model.embedder.preprocess = lambda x: x - with torch.autocast("cuda", dtype=torch.bfloat16): - cond_img_emb = model.embedder.encode_with_vision_transformer( - preprocessed) - finally: - model.embedder.preprocess = original_preprocess + img = observation['observation.images.top'].permute(0, 2, 1, 3, 4) + cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:] + if getattr(model, "encoder_bf16", False) and model.device.type == "cuda": + if getattr(model, "encoder_mode", "autocast") == "autocast": + preprocess_ctx = torch.autocast("cuda", enabled=False) + with preprocess_ctx: + cond_img_fp32 = cond_img.float() + if hasattr(model.embedder, "preprocess"): + preprocessed = model.embedder.preprocess(cond_img_fp32) else: + preprocessed = cond_img_fp32 + + if hasattr(model.embedder, + "encode_with_vision_transformer") and hasattr( + model.embedder, "preprocess"): + original_preprocess = model.embedder.preprocess + try: + model.embedder.preprocess = lambda x: x with torch.autocast("cuda", dtype=torch.bfloat16): - cond_img_emb = model.embedder(preprocessed) + cond_img_emb = model.embedder.encode_with_vision_transformer( + preprocessed) + finally: + model.embedder.preprocess = original_preprocess else: with torch.autocast("cuda", dtype=torch.bfloat16): - cond_img_emb = model.embedder(cond_img) + cond_img_emb = model.embedder(preprocessed) else: - cond_img_emb = model.embedder(cond_img) + with torch.autocast("cuda", dtype=torch.bfloat16): + cond_img_emb = model.embedder(cond_img) + else: + cond_img_emb = model.embedder(cond_img) - if model.model.conditioning_key == 'hybrid': - z = get_latent_z(model, img.permute(0, 2, 1, 3, 4)) - img_cat_cond = z[:, :, -1:, :, :] - img_cat_cond = repeat(img_cat_cond, - 'b c t h w -> b c (repeat t) h w', - repeat=noise_shape[2]) - cond = {"c_concat": [img_cat_cond]} + if model.model.conditioning_key == 'hybrid': + z = get_latent_z(model, img.permute(0, 2, 1, 3, 4)) + img_cat_cond = z[:, :, -1:, :, :] + img_cat_cond = repeat(img_cat_cond, + 'b c t h w -> b c (repeat t) h w', + repeat=noise_shape[2]) + cond = {"c_concat": [img_cat_cond]} - if not text_input: - prompts = [""] * batch_size - encoder_ctx = nullcontext() - if getattr(model, "encoder_bf16", False) and model.device.type == "cuda": - encoder_ctx = torch.autocast("cuda", dtype=torch.bfloat16) - with encoder_ctx: - cond_ins_emb = model.get_learned_conditioning(prompts) - target_dtype = cond_ins_emb.dtype + if not text_input: + prompts = [""] * batch_size + encoder_ctx = nullcontext() + if getattr(model, "encoder_bf16", False) and model.device.type == "cuda": + encoder_ctx = torch.autocast("cuda", dtype=torch.bfloat16) + with encoder_ctx: + cond_ins_emb = model.get_learned_conditioning(prompts) + target_dtype = cond_ins_emb.dtype - cond_img_emb = model._projector_forward(model.image_proj_model, - cond_img_emb, target_dtype) + cond_img_emb = model._projector_forward(model.image_proj_model, + cond_img_emb, target_dtype) - cond_state_emb = model._projector_forward( - model.state_projector, observation['observation.state'], - target_dtype) - cond_state_emb = cond_state_emb + model.agent_state_pos_emb.to( - dtype=target_dtype) + cond_state_emb = model._projector_forward( + model.state_projector, observation['observation.state'], + target_dtype) + cond_state_emb = cond_state_emb + model.agent_state_pos_emb.to( + dtype=target_dtype) - cond_action_emb = model._projector_forward( - model.action_projector, observation['action'], target_dtype) - cond_action_emb = cond_action_emb + model.agent_action_pos_emb.to( - dtype=target_dtype) + cond_action_emb = model._projector_forward( + model.action_projector, observation['action'], target_dtype) + cond_action_emb = cond_action_emb + model.agent_action_pos_emb.to( + dtype=target_dtype) - if not sim_mode: - cond_action_emb = torch.zeros_like(cond_action_emb) + if not sim_mode: + cond_action_emb = torch.zeros_like(cond_action_emb) - cond["c_crossattn"] = [ - torch.cat( - [cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb], - dim=1) - ] - cond["c_crossattn_action"] = [ - observation['observation.images.top'][:, :, - -model.n_obs_steps_acting:], - observation['observation.state'][:, -model.n_obs_steps_acting:], - sim_mode, - False, - ] + cond["c_crossattn"] = [ + torch.cat( + [cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb], + dim=1) + ] + cond["c_crossattn_action"] = [ + observation['observation.images.top'][:, :, + -model.n_obs_steps_acting:], + observation['observation.state'][:, -model.n_obs_steps_acting:], + sim_mode, + False, + ] uc = None kwargs.update({"unconditional_conditioning_img_nonetext": None}) @@ -1038,42 +639,40 @@ def image_guided_synthesis_sim_mode( cond_z0 = None if ddim_sampler is not None: - with profiler.profile_section("synthesis/ddim_sampling"): - autocast_ctx = nullcontext() - if diffusion_autocast_dtype is not None and model.device.type == "cuda": - autocast_ctx = torch.autocast("cuda", dtype=diffusion_autocast_dtype) - with autocast_ctx: - samples, actions, states, intermedia = ddim_sampler.sample( - S=ddim_steps, - conditioning=cond, - batch_size=batch_size, - shape=noise_shape[1:], - verbose=False, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=uc, - eta=ddim_eta, - cfg_img=None, - mask=cond_mask, - x0=cond_z0, - fs=fs, - timestep_spacing=timestep_spacing, - guidance_rescale=guidance_rescale, - **kwargs) + autocast_ctx = nullcontext() + if diffusion_autocast_dtype is not None and model.device.type == "cuda": + autocast_ctx = torch.autocast("cuda", dtype=diffusion_autocast_dtype) + with autocast_ctx: + samples, actions, states, intermedia = ddim_sampler.sample( + S=ddim_steps, + conditioning=cond, + batch_size=batch_size, + shape=noise_shape[1:], + verbose=False, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + eta=ddim_eta, + cfg_img=None, + mask=cond_mask, + x0=cond_z0, + fs=fs, + timestep_spacing=timestep_spacing, + guidance_rescale=guidance_rescale, + **kwargs) # Reconstruct from latent to pixel space - with profiler.profile_section("synthesis/decode_first_stage"): - if getattr(model, "vae_bf16", False): - if samples.dtype != torch.bfloat16: - samples = samples.to(dtype=torch.bfloat16) - vae_ctx = nullcontext() - if model.device.type == "cuda": - vae_ctx = torch.autocast("cuda", dtype=torch.bfloat16) - with vae_ctx: - batch_images = model.decode_first_stage(samples) - else: - if samples.dtype != torch.float32: - samples = samples.float() + if getattr(model, "vae_bf16", False): + if samples.dtype != torch.bfloat16: + samples = samples.to(dtype=torch.bfloat16) + vae_ctx = nullcontext() + if model.device.type == "cuda": + vae_ctx = torch.autocast("cuda", dtype=torch.bfloat16) + with vae_ctx: batch_images = model.decode_first_stage(samples) + else: + if samples.dtype != torch.float32: + samples = samples.float() + batch_images = model.decode_first_stage(samples) batch_variants = batch_images return batch_variants, actions, states @@ -1091,13 +690,8 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: Returns: None """ - profiler = get_profiler() - - # Create inference and tensorboard dirs + # Create inference dir os.makedirs(args.savedir + '/inference', exist_ok=True) - log_dir = args.savedir + f"/tensorboard" - os.makedirs(log_dir, exist_ok=True) - writer = SummaryWriter(log_dir=log_dir) # Load prompt csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv") @@ -1110,11 +704,10 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: if os.path.exists(prepared_path): # ---- Fast path: load the fully-prepared model ---- print(f">>> Loading prepared model from {prepared_path} ...") - with profiler.profile_section("model_loading/prepared"): - model = torch.load(prepared_path, - map_location=f"cuda:{gpu_no}", - weights_only=False, - mmap=True) + model = torch.load(prepared_path, + map_location=f"cuda:{gpu_no}", + weights_only=False, + mmap=True) model.eval() diffusion_autocast_dtype = (torch.bfloat16 if args.diffusion_dtype == "bf16" @@ -1122,17 +715,15 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: print(f">>> Prepared model loaded.") else: # ---- Normal path: construct + checkpoint + casting ---- - with profiler.profile_section("model_loading/config"): - config['model']['params']['wma_config']['params'][ - 'use_checkpoint'] = False - model = instantiate_from_config(config.model) - model.perframe_ae = args.perframe_ae + config['model']['params']['wma_config']['params'][ + 'use_checkpoint'] = False + model = instantiate_from_config(config.model) + model.perframe_ae = args.perframe_ae assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!" - with profiler.profile_section("model_loading/checkpoint"): - model = load_model_checkpoint(model, args.ckpt_path, - device=f"cuda:{gpu_no}") + model = load_model_checkpoint(model, args.ckpt_path, + device=f"cuda:{gpu_no}") model.eval() model = model.cuda(gpu_no) # move residual buffers not in state_dict print(f'>>> Load pre-trained model ...') @@ -1143,8 +734,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: model.model, torch.bfloat16, "diffusion backbone", - profiler=profiler, - profile_name="model_loading/diffusion_bf16", ) diffusion_autocast_dtype = torch.bfloat16 print(">>> diffusion backbone set to bfloat16") @@ -1155,8 +744,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: model.first_stage_model, vae_weight_dtype, "VAE", - profiler=profiler, - profile_name="model_loading/vae_cast", ) model.vae_bf16 = args.vae_dtype == "bf16" print(f">>> VAE dtype set to {args.vae_dtype}") @@ -1195,16 +782,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: model.cond_stage_model, encoder_weight_dtype, "cond_stage_model", - profiler=profiler, - profile_name="model_loading/encoder_cond_cast", ) if hasattr(model, "embedder") and model.embedder is not None: maybe_cast_module( model.embedder, encoder_weight_dtype, "embedder", - profiler=profiler, - profile_name="model_loading/encoder_embedder_cast", ) model.encoder_bf16 = encoder_bf16 model.encoder_mode = encoder_mode @@ -1220,24 +803,18 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: model.image_proj_model, projector_weight_dtype, "image_proj_model", - profiler=profiler, - profile_name="model_loading/projector_image_cast", ) if hasattr(model, "state_projector") and model.state_projector is not None: maybe_cast_module( model.state_projector, projector_weight_dtype, "state_projector", - profiler=profiler, - profile_name="model_loading/projector_state_cast", ) if hasattr(model, "action_projector") and model.action_projector is not None: maybe_cast_module( model.action_projector, projector_weight_dtype, "action_projector", - profiler=profiler, - profile_name="model_loading/projector_action_cast", ) if hasattr(model, "projector_bf16"): model.projector_bf16 = projector_bf16 @@ -1269,14 +846,11 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: # Build normalizer (always needed, independent of model loading path) logging.info("***** Configing Data *****") - with profiler.profile_section("data_loading"): - data = instantiate_from_config(config.data) - data.setup() + data = instantiate_from_config(config.data) + data.setup() print(">>> Dataset is successfully loaded ...") device = get_device_from_parameters(model) - profiler.record_memory("after_model_load") - # Run over data assert (args.height % 16 == 0) and ( args.width % 16 @@ -1290,10 +864,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: print(f'>>> Generate {n_frames} frames under each generation ...') noise_shape = [args.bs, channels, n_frames, h, w] - # Determine profiler iterations - profile_active_iters = getattr(args, 'profile_iterations', 3) - use_pytorch_profiler = profiler.enabled and profile_active_iters > 0 - # Start inference for idx in range(0, len(df)): sample = df.iloc[idx] @@ -1309,13 +879,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: # Load transitions to get the initial state later transition_path = get_transition_path(args.prompt_dir, sample) - with profiler.profile_section("load_transitions"): - with h5py.File(transition_path, 'r') as h5f: - transition_dict = {} - for key in h5f.keys(): - transition_dict[key] = torch.tensor(h5f[key][()]) - for key in h5f.attrs.keys(): - transition_dict[key] = h5f.attrs[key] + with h5py.File(transition_path, 'r') as h5f: + transition_dict = {} + for key in h5f.keys(): + transition_dict[key] = torch.tensor(h5f[key][()]) + for key in h5f.attrs.keys(): + transition_dict[key] = h5f.attrs[key] # If many, test various frequence control and world-model generation for fs in args.frame_stride: @@ -1337,185 +906,142 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: } # Obtain initial frame and state - with profiler.profile_section("prepare_init_input"): - start_idx = 0 - model_input_fs = ori_fps // fs - batch, ori_state_dim, ori_action_dim = prepare_init_input( - start_idx, - init_frame_path, - transition_dict, - fs, - data.test_datasets[args.dataset], - n_obs_steps=model.n_obs_steps_imagen) - observation = { - 'observation.images.top': - batch['observation.image'].permute(1, 0, 2, - 3)[-1].unsqueeze(0), - 'observation.state': - batch['observation.state'][-1].unsqueeze(0), - 'action': - torch.zeros_like(batch['action'][-1]).unsqueeze(0) - } - observation = _move_to_device(observation, device) + start_idx = 0 + model_input_fs = ori_fps // fs + batch, ori_state_dim, ori_action_dim = prepare_init_input( + start_idx, + init_frame_path, + transition_dict, + fs, + data.test_datasets[args.dataset], + n_obs_steps=model.n_obs_steps_imagen) + observation = { + 'observation.images.top': + batch['observation.image'].permute(1, 0, 2, + 3)[-1].unsqueeze(0), + 'observation.state': + batch['observation.state'][-1].unsqueeze(0), + 'action': + torch.zeros_like(batch['action'][-1]).unsqueeze(0) + } + observation = _move_to_device(observation, device) # Update observation queues cond_obs_queues = populate_queues(cond_obs_queues, observation) - # Setup PyTorch profiler context if enabled - pytorch_prof_ctx = nullcontext() - if use_pytorch_profiler: - pytorch_prof_ctx = profiler.start_pytorch_profiler( - wait=1, warmup=1, active=profile_active_iters - ) - # Multi-round interaction with the world-model - with pytorch_prof_ctx: - for itr in tqdm(range(args.n_iter)): - log_every = max(1, args.step_log_every) - log_step = (itr % log_every == 0) - profiler.current_iteration = itr - profiler.record_memory(f"iter_{itr}_start") + for itr in tqdm(range(args.n_iter)): + log_every = max(1, args.step_log_every) + log_step = (itr % log_every == 0) - with profiler.profile_section("iteration_total"): - # Get observation - with profiler.profile_section("prepare_observation"): - observation = { - 'observation.images.top': - torch.stack(list( - cond_obs_queues['observation.images.top']), - dim=1).permute(0, 2, 1, 3, 4), - 'observation.state': - torch.stack(list(cond_obs_queues['observation.state']), - dim=1), - 'action': - torch.stack(list(cond_obs_queues['action']), dim=1), - } - observation = _move_to_device(observation, device) + # Get observation + observation = { + 'observation.images.top': + torch.stack(list( + cond_obs_queues['observation.images.top']), + dim=1).permute(0, 2, 1, 3, 4), + 'observation.state': + torch.stack(list(cond_obs_queues['observation.state']), + dim=1), + 'action': + torch.stack(list(cond_obs_queues['action']), dim=1), + } + observation = _move_to_device(observation, device) - # Use world-model in policy to generate action - if log_step: - print(f'>>> Step {itr}: generating actions ...') - with profiler.profile_section("action_generation"): - pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode( - model, - sample['instruction'], - observation, - noise_shape, - action_cond_step=args.exe_steps, - ddim_steps=args.ddim_steps, - ddim_eta=args.ddim_eta, - unconditional_guidance_scale=args. - unconditional_guidance_scale, - fs=model_input_fs, - timestep_spacing=args.timestep_spacing, - guidance_rescale=args.guidance_rescale, - sim_mode=False, - diffusion_autocast_dtype=diffusion_autocast_dtype) + # Use world-model in policy to generate action + if log_step: + print(f'>>> Step {itr}: generating actions ...') + pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode( + model, + sample['instruction'], + observation, + noise_shape, + action_cond_step=args.exe_steps, + ddim_steps=args.ddim_steps, + ddim_eta=args.ddim_eta, + unconditional_guidance_scale=args. + unconditional_guidance_scale, + fs=model_input_fs, + timestep_spacing=args.timestep_spacing, + guidance_rescale=args.guidance_rescale, + sim_mode=False, + diffusion_autocast_dtype=diffusion_autocast_dtype) - # Update future actions in the observation queues - with profiler.profile_section("update_action_queues"): - for act_idx in range(len(pred_actions[0])): - obs_update = {'action': pred_actions[0][act_idx:act_idx + 1]} - obs_update['action'][:, ori_action_dim:] = 0.0 - cond_obs_queues = populate_queues(cond_obs_queues, - obs_update) + # Update future actions in the observation queues + for act_idx in range(len(pred_actions[0])): + obs_update = {'action': pred_actions[0][act_idx:act_idx + 1]} + obs_update['action'][:, ori_action_dim:] = 0.0 + cond_obs_queues = populate_queues(cond_obs_queues, + obs_update) - # Collect data for interacting the world-model using the predicted actions - with profiler.profile_section("prepare_wm_observation"): - observation = { - 'observation.images.top': - torch.stack(list( - cond_obs_queues['observation.images.top']), - dim=1).permute(0, 2, 1, 3, 4), - 'observation.state': - torch.stack(list(cond_obs_queues['observation.state']), - dim=1), - 'action': - torch.stack(list(cond_obs_queues['action']), dim=1), - } - observation = _move_to_device(observation, device) + # Collect data for interacting the world-model using the predicted actions + observation = { + 'observation.images.top': + torch.stack(list( + cond_obs_queues['observation.images.top']), + dim=1).permute(0, 2, 1, 3, 4), + 'observation.state': + torch.stack(list(cond_obs_queues['observation.state']), + dim=1), + 'action': + torch.stack(list(cond_obs_queues['action']), dim=1), + } + observation = _move_to_device(observation, device) - # Interaction with the world-model - if log_step: - print(f'>>> Step {itr}: interacting with world model ...') - with profiler.profile_section("world_model_interaction"): - pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode( - model, - "", - observation, - noise_shape, - action_cond_step=args.exe_steps, - ddim_steps=args.ddim_steps, - ddim_eta=args.ddim_eta, - unconditional_guidance_scale=args. - unconditional_guidance_scale, - fs=model_input_fs, - text_input=False, - timestep_spacing=args.timestep_spacing, - guidance_rescale=args.guidance_rescale, - diffusion_autocast_dtype=diffusion_autocast_dtype) + # Interaction with the world-model + if log_step: + print(f'>>> Step {itr}: interacting with world model ...') + pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode( + model, + "", + observation, + noise_shape, + action_cond_step=args.exe_steps, + ddim_steps=args.ddim_steps, + ddim_eta=args.ddim_eta, + unconditional_guidance_scale=args. + unconditional_guidance_scale, + fs=model_input_fs, + text_input=False, + timestep_spacing=args.timestep_spacing, + guidance_rescale=args.guidance_rescale, + diffusion_autocast_dtype=diffusion_autocast_dtype) - with profiler.profile_section("update_state_queues"): - for step_idx in range(args.exe_steps): - obs_update = { - 'observation.images.top': - pred_videos_1[0][:, step_idx:step_idx + 1].permute(1, 0, 2, 3), - 'observation.state': - torch.zeros_like(pred_states[0][step_idx:step_idx + 1]) if - args.zero_pred_state else pred_states[0][step_idx:step_idx + 1], - 'action': - torch.zeros_like(pred_actions[0][-1:]) - } - obs_update['observation.state'][:, ori_state_dim:] = 0.0 - cond_obs_queues = populate_queues(cond_obs_queues, - obs_update) + for step_idx in range(args.exe_steps): + obs_update = { + 'observation.images.top': + pred_videos_1[0][:, step_idx:step_idx + 1].permute(1, 0, 2, 3), + 'observation.state': + torch.zeros_like(pred_states[0][step_idx:step_idx + 1]) if + args.zero_pred_state else pred_states[0][step_idx:step_idx + 1], + 'action': + torch.zeros_like(pred_actions[0][-1:]) + } + obs_update['observation.state'][:, ori_state_dim:] = 0.0 + cond_obs_queues = populate_queues(cond_obs_queues, + obs_update) - # Save the imagen videos for decision-making (async) - with profiler.profile_section("save_results"): - sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}" - log_to_tensorboard_async(writer, - pred_videos_0, - sample_tag, - fps=args.save_fps) - # Save videos environment changes via world-model interaction - sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}" - log_to_tensorboard_async(writer, - pred_videos_1, - sample_tag, - fps=args.save_fps) + # Save the imagen videos for decision-making (async) + sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4' + save_results_async(pred_videos_0, + sample_video_file, + fps=args.save_fps) + # Save videos environment changes via world-model interaction + sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4' + save_results_async(pred_videos_1, + sample_video_file, + fps=args.save_fps) - # Save the imagen videos for decision-making - sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4' - save_results_async(pred_videos_0, - sample_video_file, - fps=args.save_fps) - # Save videos environment changes via world-model interaction - sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4' - save_results_async(pred_videos_1, - sample_video_file, - fps=args.save_fps) - - print('>' * 24) - # Collect the result of world-model interactions - wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu()) - - profiler.record_memory(f"iter_{itr}_end") - profiler.step_profiler() + print('>' * 24) + # Collect the result of world-model interactions + wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu()) full_video = torch.cat(wm_video, dim=2) - sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full" - log_to_tensorboard_async(writer, - full_video, - sample_tag, - fps=args.save_fps) sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4" save_results_async(full_video, sample_full_video_file, fps=args.save_fps) - # Wait for all async I/O to complete before profiling report + # Wait for all async I/O to complete _flush_io() - # Save profiling results - profiler.save_results() - def get_parser(): parser = argparse.ArgumentParser() @@ -1704,32 +1230,6 @@ def get_parser(): type=int, default=8, help="fps for the saving video") - # Profiling arguments - parser.add_argument( - "--profile", - action='store_true', - default=False, - help="Enable performance profiling (macro and operator-level analysis)." - ) - parser.add_argument( - "--profile_output_dir", - type=str, - default=None, - help="Directory to save profiling results. Defaults to {savedir}/profile_output." - ) - parser.add_argument( - "--profile_iterations", - type=int, - default=3, - help="Number of iterations to run PyTorch profiler's active phase for operator-level analysis." - ) - parser.add_argument( - "--profile_detail", - type=str, - choices=["light", "full"], - default="light", - help="Profiling detail level. Use 'full' for shapes/stacks/memory/flops." - ) return parser @@ -1741,15 +1241,5 @@ if __name__ == '__main__': seed = random.randint(0, 2**31) seed_everything(seed) - # Initialize profiler - profile_output_dir = args.profile_output_dir - if profile_output_dir is None: - profile_output_dir = os.path.join(args.savedir, "profile_output") - init_profiler( - enabled=args.profile, - output_dir=profile_output_dir, - profile_detail=args.profile_detail, - ) - rank, gpu_num = 0, 1 run_inference(args, gpu_num, rank)