diff --git a/profile_unet_flops.md b/profile_unet_flops.md index 2ab4024..6dcdd6c 100644 --- a/profile_unet_flops.md +++ b/profile_unet_flops.md @@ -118,4 +118,100 @@ SUMMARY Matmul throughput: 33.11 TFLOPS/s (54.3% of BF16 peak) Overall throughput: 18.87 TFLOPS/s (30.9% of BF16 peak) GPU peak (BF16): 61.0 TFLOPS -(unifolm-wma) ASC1637@wx-ms-w7900d-0033:/mnt/ASC1637/unifolm-world-model-action$ \ No newline at end of file +(unifolm-wma) ASC1637@wx-ms-w7900d-0033:/mnt/ASC1637/unifolm-world-model-action$ + +======================================================================== +TABLE 1: STAGE TIMING +======================================================================== +Stage Mean(ms) Std % +------------------------------------------------------------------------ +1_Image_Embedding 29.5 0.16 0.1% +2_VAE_Encode 51.3 0.06 0.1% +3_Text_Conditioning 14.7 0.18 0.0% +4_Projectors 0.2 0.03 0.0% +5_DDIM_Loop 33392.5 3.21 97.3% +6_VAE_Decode 808.4 1.00 2.4% +7_Post_Process 15.8 0.56 0.0% +------------------------------------------------------------------------ +TOTAL 34312.4 + +================================================================================ +TABLE 2: UNET SUB-MODULE BREAKDOWN +================================================================================ +Module Type Total(ms) Count Per-call % +-------------------------------------------------------------------------------- +ResBlock 10256.3 1100 9.32 23.2% +SpatialTransformer 9228.2 800 11.54 20.9% +CrossAttention 8105.8 3300 2.46 18.3% +ConditionalUnet1D 6409.5 100 64.10 14.5% +TemporalTransformer 5847.0 850 6.88 13.2% +FeedForward 4338.1 1650 2.63 9.8% +UNet.out 73.8 50 1.48 0.2% +-------------------------------------------------------------------------------- +TOTAL (hooked) 44258.7 + +========================================================================================== +TABLE 2b: PER-UNET-BLOCK TIMING (coarse modules, per DDIM loop) +========================================================================================== +Block Total(ms) % Breakdown +------------------------------------------------------------------------------------------ +input_blocks.1 3376.2 7.6% SpatialTransformer=1101, CrossAttention=990, ResBlock=543, TemporalTransformer=454, FeedForward=288 +input_blocks.2 3374.0 7.6% SpatialTransformer=1100, CrossAttention=991, ResBlock=540, TemporalTransformer=455, FeedForward=288 +input_blocks.4 1592.4 3.6% SpatialTransformer=394, ResBlock=374, CrossAttention=303, TemporalTransformer=272, FeedForward=249 +input_blocks.5 1642.5 3.7% ResBlock=425, SpatialTransformer=397, CrossAttention=303, TemporalTransformer=271, FeedForward=247 +input_blocks.7 1469.0 3.3% ResBlock=416, SpatialTransformer=324, FeedForward=251, CrossAttention=240, TemporalTransformer=237 +input_blocks.8 1543.7 3.5% ResBlock=491, SpatialTransformer=325, FeedForward=250, CrossAttention=240, TemporalTransformer=238 +input_blocks.10 217.5 0.5% ResBlock=218 +input_blocks.11 216.8 0.5% ResBlock=217 +middle_block 848.9 1.9% ResBlock=434, SpatialTransformer=151, CrossAttention=134, TemporalTransformer=69, FeedForward=61 +output_blocks.0 303.2 0.7% ResBlock=303 +output_blocks.1 303.1 0.7% ResBlock=303 +output_blocks.2 302.8 0.7% ResBlock=303 +output_blocks.3 1734.8 3.9% ResBlock=687, SpatialTransformer=322, FeedForward=249, CrossAttention=239, TemporalTransformer=237 +output_blocks.4 1739.8 3.9% ResBlock=688, SpatialTransformer=323, FeedForward=251, CrossAttention=239, TemporalTransformer=238 +output_blocks.5 1622.3 3.7% ResBlock=570, SpatialTransformer=324, FeedForward=251, CrossAttention=239, TemporalTransformer=238 +output_blocks.6 1881.0 4.3% ResBlock=664, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=250 +output_blocks.7 1768.0 4.0% ResBlock=554, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=249 +output_blocks.8 1688.7 3.8% ResBlock=474, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=249 +output_blocks.9 3558.6 8.0% SpatialTransformer=1096, CrossAttention=992, ResBlock=727, TemporalTransformer=454, FeedForward=290 +output_blocks.10 3492.8 7.9% SpatialTransformer=1096, CrossAttention=992, ResBlock=662, TemporalTransformer=454, FeedForward=289 +output_blocks.11 3493.3 7.9% SpatialTransformer=1096, CrossAttention=992, ResBlock=662, TemporalTransformer=454, FeedForward=289 +out 73.8 0.2% UNet.out=74 +action_unet 3212.0 7.3% ConditionalUnet1D=3212 +state_unet 3197.6 7.2% ConditionalUnet1D=3198 +other 1606.2 3.6% TemporalTransformer=960, FeedForward=337, CrossAttention=309 +------------------------------------------------------------------------------------------ +TOTAL 44258.7 + +====================================================================== +TABLE 2c: ATTENTION vs FEEDFORWARD (deep hooks) +====================================================================== +Component Total(ms) % +---------------------------------------------------------------------- +CrossAttention 8105.8 65.1% +FeedForward 4338.1 34.9% +---------------------------------------------------------------------- +TOTAL (attn+ff) 12443.9 + +================================================== +TABLE 3: MEMORY SUMMARY +================================================== + Initial allocated: 11.82 GB + Peak allocated: 14.43 GB + Delta (pipeline): 2.61 GB + +============================================================ +TABLE 4: THROUGHPUT +============================================================ + Total pipeline latency: 34312.4 ms + DDIM loop latency: 33392.5 ms + DDIM steps: 50 + CFG scale: 1.0 (1x UNet/step) + UNet forward calls: 50 + Per DDIM step: 667.9 ms + Per UNet forward: 667.9 ms + VAE encode bandwidth: 0.1 GB/s (peak HBM: 864.0 GB/s) + VAE decode bandwidth: 0.0 GB/s (peak HBM: 864.0 GB/s) + GPU BF16 peak: 61.0 TFLOPS + +Done. \ No newline at end of file diff --git a/scripts/evaluation/profile_pipeline.py b/scripts/evaluation/profile_pipeline.py new file mode 100644 index 0000000..31e078c --- /dev/null +++ b/scripts/evaluation/profile_pipeline.py @@ -0,0 +1,733 @@ +""" +Profile the full inference pipeline of the world model, covering all 7 stages: + 1. Image Embedding + 2. VAE Encode + 3. Text Conditioning + 4. State/Action Projectors + 5. DDIM Loop + 6. VAE Decode + 7. Post-process + +Reports stage-level timing, UNet sub-module breakdown, memory summary, +and throughput analysis. +TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_pipeline.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --ddim_steps 50 --cfg_scale 1.0 --n_runs 3 --deep +Usage: + TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_pipeline.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --ddim_steps 50 --cfg_scale 1.0 --n_runs 3 +""" + +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint # must be loaded before unifolm_wma.utils.common +from contextlib import nullcontext, contextmanager +from collections import defaultdict +from omegaconf import OmegaConf +from einops import rearrange, repeat + +from unifolm_wma.utils.utils import instantiate_from_config +from unifolm_wma.models.samplers.ddim import DDIMSampler +from unifolm_wma.modules.attention import ( + SpatialTransformer, TemporalTransformer, + BasicTransformerBlock, CrossAttention, FeedForward, +) +from unifolm_wma.modules.networks.wma_model import ResBlock +from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D + +# --- W7900D theoretical peak --- +PEAK_BF16_TFLOPS = 61.0 +MEM_BW_GBS = 864.0 + + +# --------------------------------------------------------------------------- +# Utility: patch norms to bypass autocast fp32 promotion +# --------------------------------------------------------------------------- +def patch_norm_bypass_autocast(): + """Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy.""" + + def _group_norm_forward(self, x): + with torch.amp.autocast('cuda', enabled=False): + return F.group_norm( + x, self.num_groups, + self.weight.to(x.dtype) if self.weight is not None else None, + self.bias.to(x.dtype) if self.bias is not None else None, + self.eps) + + def _layer_norm_forward(self, x): + with torch.amp.autocast('cuda', enabled=False): + return F.layer_norm( + x, self.normalized_shape, + self.weight.to(x.dtype) if self.weight is not None else None, + self.bias.to(x.dtype) if self.bias is not None else None, + self.eps) + + torch.nn.GroupNorm.forward = _group_norm_forward + torch.nn.LayerNorm.forward = _layer_norm_forward + + +# --------------------------------------------------------------------------- +# Utility: torch.compile hot ResBlocks +# --------------------------------------------------------------------------- +def apply_torch_compile(model, hot_indices=(5, 8, 9)): + """Compile ResBlock._forward in the hottest output_blocks for operator fusion.""" + unet = model.model.diffusion_model + compiled = 0 + for idx in hot_indices: + block = unet.output_blocks[idx] + for layer in block: + if isinstance(layer, ResBlock): + layer._forward = torch.compile(layer._forward, mode="default") + compiled += 1 + print(f" torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}") + + +# --------------------------------------------------------------------------- +# Model loading +# --------------------------------------------------------------------------- +def load_model(args): + config = OmegaConf.load(args.config) + config['model']['params']['wma_config']['params']['use_checkpoint'] = False + model = instantiate_from_config(config.model) + + state_dict = torch.load(args.ckpt_path, map_location="cpu") + if "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + model.load_state_dict(state_dict, strict=True) + + model.eval() + model.model.to(torch.bfloat16) + model.diffusion_autocast_dtype = torch.bfloat16 + apply_torch_compile(model) + model = model.cuda() + return model + + +# --------------------------------------------------------------------------- +# CudaTimer — precise GPU timing via CUDA events +# --------------------------------------------------------------------------- +class CudaTimer: + """Context manager for GPU-precise stage timing using CUDA events.""" + + def __init__(self, name, records): + self.name = name + self.records = records + self.start = torch.cuda.Event(enable_timing=True) + self.end = torch.cuda.Event(enable_timing=True) + + def __enter__(self): + torch.cuda.synchronize() + self.start.record() + return self + + def __exit__(self, *args): + self.end.record() + torch.cuda.synchronize() + elapsed = self.start.elapsed_time(self.end) + self.records[self.name].append(elapsed) + + +# --------------------------------------------------------------------------- +# HookProfiler — sub-module level timing inside UNet via hooks +# --------------------------------------------------------------------------- +class HookProfiler: + """Register forward hooks on UNet sub-modules to collect per-call timing.""" + + # Coarse-grained targets (original) + COARSE_CLASSES = ( + SpatialTransformer, + TemporalTransformer, + ResBlock, + ConditionalUnet1D, + ) + + # Fine-grained targets for deep DDIM analysis + FINE_CLASSES = ( + CrossAttention, + FeedForward, + ) + + def __init__(self, unet, deep=False): + self.unet = unet + self.deep = deep + self.handles = [] + # per-instance data: {instance_id: [(start_event, end_event), ...]} + self._events = defaultdict(list) + # tag mapping: {instance_id: (class_name, module_name)} + self._tags = {} + # block location: {instance_id: block_location_str} + self._block_loc = {} + + @staticmethod + def _get_block_location(name): + """Derive UNet block location from module name, e.g. 'input_blocks.3.1'.""" + parts = name.split('.') + if len(parts) >= 2 and parts[0] == 'input_blocks': + return f"input_blocks.{parts[1]}" + elif len(parts) >= 1 and parts[0] == 'middle_block': + return "middle_block" + elif len(parts) >= 2 and parts[0] == 'output_blocks': + return f"output_blocks.{parts[1]}" + elif 'action_unet' in name: + return "action_unet" + elif 'state_unet' in name: + return "state_unet" + elif name == 'out' or name.startswith('out.'): + return "out" + return "other" + + def register(self): + """Attach pre/post forward hooks to target sub-modules + unet.out.""" + target_classes = self.COARSE_CLASSES + if self.deep: + target_classes = target_classes + self.FINE_CLASSES + + for name, mod in self.unet.named_modules(): + if isinstance(mod, target_classes): + tag = type(mod).__name__ + inst_id = id(mod) + self._tags[inst_id] = (tag, name) + self._block_loc[inst_id] = self._get_block_location(name) + self.handles.append( + mod.register_forward_pre_hook(self._make_pre_hook(inst_id))) + self.handles.append( + mod.register_forward_hook(self._make_post_hook(inst_id))) + + # Also hook unet.out (nn.Sequential) + out_mod = self.unet.out + inst_id = id(out_mod) + self._tags[inst_id] = ("UNet.out", "out") + self._block_loc[inst_id] = "out" + self.handles.append( + out_mod.register_forward_pre_hook(self._make_pre_hook(inst_id))) + self.handles.append( + out_mod.register_forward_hook(self._make_post_hook(inst_id))) + + def _make_pre_hook(self, inst_id): + events = self._events + + def hook(module, input): + start = torch.cuda.Event(enable_timing=True) + start.record() + events[inst_id].append([start, None]) + return hook + + def _make_post_hook(self, inst_id): + events = self._events + + def hook(module, input, output): + end = torch.cuda.Event(enable_timing=True) + end.record() + events[inst_id][-1][1] = end + return hook + + def reset(self): + """Clear collected events for a fresh run.""" + self._events.clear() + + def synchronize_and_collect(self): + """Sync GPU and compute elapsed times. Returns (by_type, by_instance, by_block).""" + torch.cuda.synchronize() + by_type = defaultdict(lambda: {"total_ms": 0.0, "count": 0, "calls": []}) + by_instance = {} + # by_block: {block_loc: {tag: {"total_ms", "count"}}} + by_block = defaultdict(lambda: defaultdict(lambda: {"total_ms": 0.0, "count": 0})) + + for inst_id, pairs in self._events.items(): + tag, mod_name = self._tags[inst_id] + block_loc = self._block_loc.get(inst_id, "other") + inst_times = [] + for start_evt, end_evt in pairs: + if end_evt is not None: + ms = start_evt.elapsed_time(end_evt) + inst_times.append(ms) + by_type[tag]["total_ms"] += ms + by_type[tag]["count"] += 1 + by_type[tag]["calls"].append(ms) + by_block[block_loc][tag]["total_ms"] += ms + by_block[block_loc][tag]["count"] += 1 + by_instance[(tag, mod_name)] = inst_times + + return dict(by_type), by_instance, dict(by_block) + + def remove(self): + """Remove all hooks.""" + for h in self.handles: + h.remove() + self.handles.clear() + + +# --------------------------------------------------------------------------- +# Build dummy inputs matching the pipeline's expected shapes +# --------------------------------------------------------------------------- +def build_dummy_inputs(model, noise_shape): + """Create synthetic observation dict and prompts for profiling.""" + device = next(model.parameters()).device + B, C, T, H, W = noise_shape + dtype = torch.bfloat16 + + # observation.images.top: [B, C, O, H, W] (permuted to [B,O,C,H,W] inside pipeline) + O = 2 + obs_images = torch.randn(B, 3, O, 320, 512, device=device, dtype=dtype) + obs_state = torch.randn(B, O, 16, device=device, dtype=dtype) + action = torch.randn(B, 16, 16, device=device, dtype=dtype) + + observation = { + 'observation.images.top': obs_images, + 'observation.state': obs_state, + 'action': action, + } + prompts = ["a robot arm performing a task"] * B + return observation, prompts + + +# --------------------------------------------------------------------------- +# Run one full pipeline pass with per-stage timing +# --------------------------------------------------------------------------- +def run_pipeline(model, observation, prompts, noise_shape, ddim_steps, + cfg_scale, hook_profiler): + """Execute the full 7-stage pipeline, returning per-stage timing dict.""" + records = defaultdict(list) + device = next(model.parameters()).device + B, C, T, H, W = noise_shape + dtype = torch.bfloat16 + + fs = torch.tensor([1] * B, dtype=torch.long, device=device) + + # --- Stage 1: Image Embedding --- + with CudaTimer("1_Image_Embedding", records): + img = observation['observation.images.top'].permute(0, 2, 1, 3, 4) + cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=dtype) + with torch.autocast('cuda', dtype=torch.bfloat16): + cond_img_emb = model.embedder(cond_img) + cond_img_emb = model.image_proj_model(cond_img_emb) + + # --- Stage 2: VAE Encode --- + with CudaTimer("2_VAE_Encode", records): + videos = img.permute(0, 2, 1, 3, 4) # [B, C, O, H, W] + b_v, c_v, t_v, h_v, w_v = videos.shape + vae_dtype = next(model.first_stage_model.parameters()).dtype + x_vae = rearrange(videos, 'b c t h w -> (b t) c h w').to(dtype=vae_dtype) + z = model.encode_first_stage(x_vae) + z = rearrange(z, '(b t) c h w -> b c t h w', b=b_v, t=t_v) + img_cat_cond = z[:, :, -1:, :, :] + img_cat_cond = repeat(img_cat_cond, + 'b c t h w -> b c (repeat t) h w', repeat=T) + cond = {"c_concat": [img_cat_cond]} + + vae_enc_input_bytes = x_vae.nelement() * x_vae.element_size() + vae_enc_output_bytes = z.nelement() * z.element_size() + + # --- Stage 3: Text Conditioning --- + with CudaTimer("3_Text_Conditioning", records): + cond_ins_emb = model.get_learned_conditioning(prompts) + + # --- Stage 4: State/Action Projectors --- + with CudaTimer("4_Projectors", records): + projector_dtype = next(model.state_projector.parameters()).dtype + with torch.autocast('cuda', dtype=torch.bfloat16): + cond_state_emb = model.state_projector( + observation['observation.state'].to(dtype=projector_dtype)) + cond_state_emb = cond_state_emb + model.agent_state_pos_emb + + cond_action_emb = model.action_projector( + observation['action'].to(dtype=projector_dtype)) + cond_action_emb = cond_action_emb + model.agent_action_pos_emb + + # Assemble cross-attention conditioning + cond["c_crossattn"] = [ + torch.cat([cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb], + dim=1) + ] + n_obs_acting = getattr(model, 'n_obs_steps_acting', 2) + cond["c_crossattn_action"] = [ + observation['observation.images.top'][:, :, -n_obs_acting:], + observation['observation.state'][:, -n_obs_acting:], + True, # sim_mode + False, + ] + + # CFG: build unconditional conditioning if needed + uc = None + if cfg_scale != 1.0: + uc_crossattn = torch.zeros_like(cond["c_crossattn"][0]) + uc = { + "c_concat": cond["c_concat"], + "c_crossattn": [uc_crossattn], + "c_crossattn_action": cond["c_crossattn_action"], + } + + # --- Stage 5: DDIM Loop --- + ddim_sampler = DDIMSampler(model) + hook_profiler.reset() + + with CudaTimer("5_DDIM_Loop", records): + with torch.autocast('cuda', dtype=torch.bfloat16): + samples, actions, states, _ = ddim_sampler.sample( + S=ddim_steps, + conditioning=cond, + batch_size=B, + shape=noise_shape[1:], + verbose=False, + unconditional_guidance_scale=cfg_scale, + unconditional_conditioning=uc, + eta=1.0, + cfg_img=None, + mask=None, + x0=None, + fs=fs, + timestep_spacing='uniform', + guidance_rescale=0.0, + unconditional_conditioning_img_nonetext=None, + ) + + hook_by_type, hook_by_instance, hook_by_block = hook_profiler.synchronize_and_collect() + + # --- Stage 6: VAE Decode --- + with CudaTimer("6_VAE_Decode", records): + batch_images = model.decode_first_stage(samples) + + vae_dec_input_bytes = samples.nelement() * samples.element_size() + vae_dec_output_bytes = batch_images.nelement() * batch_images.element_size() + + # --- Stage 7: Post-process --- + with CudaTimer("7_Post_Process", records): + batch_images_cpu = batch_images.cpu() + actions_cpu = actions.cpu() + states_cpu = states.cpu() + # Simulate video save overhead: clamp + uint8 conversion + _ = (batch_images_cpu.clamp(-1, 1) * 127.5 + 127.5).to(torch.uint8) + + # Flatten single-element lists + stage_times = {k: v[0] for k, v in records.items()} + + bandwidth_info = { + "vae_enc_input_bytes": vae_enc_input_bytes, + "vae_enc_output_bytes": vae_enc_output_bytes, + "vae_dec_input_bytes": vae_dec_input_bytes, + "vae_dec_output_bytes": vae_dec_output_bytes, + } + + return stage_times, hook_by_type, hook_by_instance, hook_by_block, bandwidth_info + + +# --------------------------------------------------------------------------- +# Reporting +# --------------------------------------------------------------------------- +def print_stage_timing(all_runs_stages): + """Table 1: Stage Timing — name | mean(ms) | std | percent.""" + import numpy as np + stage_names = list(all_runs_stages[0].keys()) + means = {} + stds = {} + for name in stage_names: + vals = [run[name] for run in all_runs_stages] + means[name] = np.mean(vals) + stds[name] = np.std(vals) + total = sum(means.values()) + + print() + print("=" * 72) + print("TABLE 1: STAGE TIMING") + print("=" * 72) + print(f"{'Stage':<25} {'Mean(ms)':>10} {'Std':>10} {'%':>8}") + print("-" * 72) + for name in stage_names: + pct = means[name] / total * 100 if total > 0 else 0 + print(f"{name:<25} {means[name]:>10.1f} {stds[name]:>10.2f} {pct:>7.1f}%") + print("-" * 72) + print(f"{'TOTAL':<25} {total:>10.1f}") + print() + + +def print_unet_breakdown(all_runs_hooks): + """Table 2: UNet Sub-Module Breakdown — type | total | count | per-call | percent.""" + import numpy as np + # Aggregate across runs + agg = defaultdict(lambda: {"totals": [], "counts": []}) + for hook_by_type in all_runs_hooks: + for tag, data in hook_by_type.items(): + agg[tag]["totals"].append(data["total_ms"]) + agg[tag]["counts"].append(data["count"]) + + print("=" * 80) + print("TABLE 2: UNET SUB-MODULE BREAKDOWN") + print("=" * 80) + print(f"{'Module Type':<25} {'Total(ms)':>10} {'Count':>7} {'Per-call':>10} {'%':>8}") + print("-" * 80) + + grand_total = 0 + rows = [] + for tag, d in agg.items(): + mean_total = np.mean(d["totals"]) + mean_count = np.mean(d["counts"]) + per_call = mean_total / mean_count if mean_count > 0 else 0 + grand_total += mean_total + rows.append((tag, mean_total, mean_count, per_call)) + + rows.sort(key=lambda r: r[1], reverse=True) + for tag, mean_total, mean_count, per_call in rows: + pct = mean_total / grand_total * 100 if grand_total > 0 else 0 + print(f"{tag:<25} {mean_total:>10.1f} {int(mean_count):>7} {per_call:>10.2f} {pct:>7.1f}%") + print("-" * 80) + print(f"{'TOTAL (hooked)':<25} {grand_total:>10.1f}") + print() + + +def print_block_timing(all_runs_blocks): + """Table 2b: Per-UNet-block timing — which blocks are hottest.""" + import numpy as np + # Aggregate: {block_loc: {tag: [total_ms_per_run, ...]}} + agg = defaultdict(lambda: defaultdict(list)) + for by_block in all_runs_blocks: + for block_loc, tag_dict in by_block.items(): + for tag, data in tag_dict.items(): + agg[block_loc][tag].append(data["total_ms"]) + + # Compute per-block totals + block_totals = {} + for block_loc, tag_dict in agg.items(): + block_totals[block_loc] = sum(np.mean(v) for v in tag_dict.values()) + + grand_total = sum(block_totals.values()) + + # Sort blocks in logical order + def block_sort_key(name): + if name.startswith("input_blocks."): + return (0, int(name.split('.')[1])) + elif name == "middle_block": + return (1, 0) + elif name.startswith("output_blocks."): + return (2, int(name.split('.')[1])) + elif name == "out": + return (3, 0) + elif name == "action_unet": + return (4, 0) + elif name == "state_unet": + return (5, 0) + return (9, 0) + + sorted_blocks = sorted(block_totals.keys(), key=block_sort_key) + + print("=" * 90) + print("TABLE 2b: PER-UNET-BLOCK TIMING (coarse modules, per DDIM loop)") + print("=" * 90) + print(f"{'Block':<22} {'Total(ms)':>10} {'%':>7} Breakdown") + print("-" * 90) + + for block_loc in sorted_blocks: + total = block_totals[block_loc] + pct = total / grand_total * 100 if grand_total > 0 else 0 + # Build breakdown string + parts = [] + for tag, vals in sorted(agg[block_loc].items(), + key=lambda x: np.mean(x[1]), reverse=True): + parts.append(f"{tag}={np.mean(vals):.0f}") + breakdown = ", ".join(parts) + print(f"{block_loc:<22} {total:>10.1f} {pct:>6.1f}% {breakdown}") + + print("-" * 90) + print(f"{'TOTAL':<22} {grand_total:>10.1f}") + print() + + +def print_attn_ff_breakdown(all_runs_hooks): + """Table 2c: CrossAttention vs FeedForward breakdown (--deep mode).""" + import numpy as np + agg = defaultdict(list) + for hook_by_type in all_runs_hooks: + for tag, data in hook_by_type.items(): + if tag in ("CrossAttention", "FeedForward"): + agg[tag].append(data["total_ms"]) + + if not agg: + return + + print("=" * 70) + print("TABLE 2c: ATTENTION vs FEEDFORWARD (deep hooks)") + print("=" * 70) + print(f"{'Component':<25} {'Total(ms)':>10} {'%':>8}") + print("-" * 70) + + grand = 0 + rows = [] + for tag in ("CrossAttention", "FeedForward"): + if tag in agg: + mean_t = np.mean(agg[tag]) + grand += mean_t + rows.append((tag, mean_t)) + + for tag, mean_t in rows: + pct = mean_t / grand * 100 if grand > 0 else 0 + print(f"{tag:<25} {mean_t:>10.1f} {pct:>7.1f}%") + print("-" * 70) + print(f"{'TOTAL (attn+ff)':<25} {grand:>10.1f}") + print() + + +def print_unet_detailed(all_runs_instances): + """Print per-instance UNet sub-module detail (--detailed mode).""" + import numpy as np + # Use last run's data + by_instance = all_runs_instances[-1] + print("=" * 100) + print("DETAILED: PER-INSTANCE UNET SUB-MODULE TIMING (last run)") + print("=" * 100) + print(f"{'Type':<22} {'Module Name':<45} {'Calls':>6} {'Total(ms)':>10} {'Mean(ms)':>10}") + print("-" * 100) + + rows = [] + for (tag, mod_name), times in by_instance.items(): + if len(times) == 0: + continue + total = sum(times) + mean = np.mean(times) + rows.append((tag, mod_name, len(times), total, mean)) + rows.sort(key=lambda r: r[3], reverse=True) + + for tag, mod_name, count, total, mean in rows: + short_name = mod_name[-42:] if len(mod_name) > 42 else mod_name + print(f"{tag:<22} {short_name:<45} {count:>6} {total:>10.2f} {mean:>10.3f}") + print() + + +def print_memory_summary(mem_before, mem_peak): + """Table 3: Memory Summary.""" + delta = mem_peak - mem_before + print("=" * 50) + print("TABLE 3: MEMORY SUMMARY") + print("=" * 50) + print(f" Initial allocated: {mem_before / 1e9:.2f} GB") + print(f" Peak allocated: {mem_peak / 1e9:.2f} GB") + print(f" Delta (pipeline): {delta / 1e9:.2f} GB") + print() + + +def print_throughput(all_runs_stages, all_bw, ddim_steps, cfg_scale): + """Table 4: Throughput — total latency, per-step, per-UNet-forward, VAE bandwidth.""" + import numpy as np + n_runs = len(all_runs_stages) + + # Total latency + totals = [] + for run in all_runs_stages: + totals.append(sum(run.values())) + mean_total = np.mean(totals) + + # DDIM loop time + ddim_times = [run["5_DDIM_Loop"] for run in all_runs_stages] + mean_ddim = np.mean(ddim_times) + + unet_calls = ddim_steps if cfg_scale == 1.0 else ddim_steps * 2 + per_step = mean_ddim / ddim_steps + per_unet = mean_ddim / unet_calls + + # VAE bandwidth + mean_enc_time = np.mean([run["2_VAE_Encode"] for run in all_runs_stages]) + mean_dec_time = np.mean([run["6_VAE_Decode"] for run in all_runs_stages]) + + bw = all_bw[-1] # use last run's byte counts + enc_bytes = bw["vae_enc_input_bytes"] + bw["vae_enc_output_bytes"] + dec_bytes = bw["vae_dec_input_bytes"] + bw["vae_dec_output_bytes"] + enc_bw = enc_bytes / (mean_enc_time / 1000) / 1e9 if mean_enc_time > 0 else 0 + dec_bw = dec_bytes / (mean_dec_time / 1000) / 1e9 if mean_dec_time > 0 else 0 + + print("=" * 60) + print("TABLE 4: THROUGHPUT") + print("=" * 60) + print(f" Total pipeline latency: {mean_total:.1f} ms") + print(f" DDIM loop latency: {mean_ddim:.1f} ms") + print(f" DDIM steps: {ddim_steps}") + print(f" CFG scale: {cfg_scale} ({'2x UNet/step' if cfg_scale != 1.0 else '1x UNet/step'})") + print(f" UNet forward calls: {unet_calls}") + print(f" Per DDIM step: {per_step:.1f} ms") + print(f" Per UNet forward: {per_unet:.1f} ms") + print(f" VAE encode bandwidth: {enc_bw:.1f} GB/s (peak HBM: {MEM_BW_GBS} GB/s)") + print(f" VAE decode bandwidth: {dec_bw:.1f} GB/s (peak HBM: {MEM_BW_GBS} GB/s)") + print(f" GPU BF16 peak: {PEAK_BF16_TFLOPS} TFLOPS") + print() + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- +def main(): + patch_norm_bypass_autocast() + + parser = argparse.ArgumentParser( + description="Profile the full inference pipeline") + parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--ddim_steps", type=int, default=50) + parser.add_argument("--cfg_scale", type=float, default=1.0) + parser.add_argument("--n_runs", type=int, default=3) + parser.add_argument("--warmup", type=int, default=1) + parser.add_argument("--detailed", action="store_true", + help="Print per-instance UNet sub-module detail") + parser.add_argument("--deep", action="store_true", + help="Enable deep DDIM analysis: per-block, attn vs ff") + args = parser.parse_args() + + noise_shape = [1, 4, 16, 40, 64] + + # --- Load model --- + print("Loading model...") + model = load_model(args) + observation, prompts = build_dummy_inputs(model, noise_shape) + + # --- Setup hook profiler --- + unet = model.model.diffusion_model + hook_profiler = HookProfiler(unet, deep=args.deep) + hook_profiler.register() + print(f"Registered hooks on {len(hook_profiler.handles)} sub-modules") + + # --- Warmup --- + print(f"Warmup: {args.warmup} run(s)...") + with torch.no_grad(): + for i in range(args.warmup): + run_pipeline(model, observation, prompts, noise_shape, + args.ddim_steps, args.cfg_scale, hook_profiler) + print(f" warmup {i+1}/{args.warmup} done") + + # --- Measurement runs --- + print(f"Measuring: {args.n_runs} run(s)...") + torch.cuda.reset_peak_memory_stats() + mem_before = torch.cuda.memory_allocated() + + all_stages = [] + all_hooks = [] + all_instances = [] + all_blocks = [] + all_bw = [] + + with torch.no_grad(): + for i in range(args.n_runs): + stage_times, hook_by_type, hook_by_instance, hook_by_block, bw = run_pipeline( + model, observation, prompts, noise_shape, + args.ddim_steps, args.cfg_scale, hook_profiler) + all_stages.append(stage_times) + all_hooks.append(hook_by_type) + all_instances.append(hook_by_instance) + all_blocks.append(hook_by_block) + all_bw.append(bw) + total = sum(stage_times.values()) + print(f" run {i+1}/{args.n_runs}: {total:.1f} ms total") + + mem_peak = torch.cuda.max_memory_allocated() + + # --- Reports --- + print_stage_timing(all_stages) + print_unet_breakdown(all_hooks) + print_block_timing(all_blocks) + if args.deep: + print_attn_ff_breakdown(all_hooks) + if args.detailed: + print_unet_detailed(all_instances) + print_memory_summary(mem_before, mem_peak) + print_throughput(all_stages, all_bw, args.ddim_steps, args.cfg_scale) + + hook_profiler.remove() + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/src/unifolm_wma/modules/attention.py b/src/unifolm_wma/modules/attention.py index 5b7e1b7..a126396 100644 --- a/src/unifolm_wma/modules/attention.py +++ b/src/unifolm_wma/modules/attention.py @@ -86,9 +86,8 @@ class CrossAttention(nn.Module): self.relative_position_v = RelativePosition( num_units=dim_head, max_relative_position=temporal_length) else: - ## only used for spatial attention, while NOT for temporal attention - if XFORMERS_IS_AVAILBLE and temporal_length is None: - self.forward = self.efficient_forward + ## bmm fused-scale attention for all non-relative-position cases + self.forward = self.bmm_forward self.video_length = video_length self.image_cross_attention = image_cross_attention @@ -234,6 +233,119 @@ class CrossAttention(nn.Module): return self.to_out(out) + def bmm_forward(self, x, context=None, mask=None): + spatial_self_attn = (context is None) + k_ip, v_ip, out_ip = None, None, None + k_as, v_as, out_as = None, None, None + k_aa, v_aa, out_aa = None, None, None + + h = self.heads + q = self.to_q(x) + context = default(context, x) + + if self.image_cross_attention and not spatial_self_attn: + context_agent_state = context[:, :self.agent_state_context_len, :] + context_agent_action = context[:, + self.agent_state_context_len:self. + agent_state_context_len + + self.agent_action_context_len, :] + context_ins = context[:, self.agent_state_context_len + + self.agent_action_context_len:self. + agent_state_context_len + + self.agent_action_context_len + + self.text_context_len, :] + context_image = context[:, self.agent_state_context_len + + self.agent_action_context_len + + self.text_context_len:, :] + + k = self.to_k(context_ins) + v = self.to_v(context_ins) + k_ip = self.to_k_ip(context_image) + v_ip = self.to_v_ip(context_image) + k_as = self.to_k_as(context_agent_state) + v_as = self.to_v_as(context_agent_state) + k_aa = self.to_k_aa(context_agent_action) + v_aa = self.to_v_aa(context_agent_action) + else: + if not spatial_self_attn: + context = context[:, :self.text_context_len, :] + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), + (q, k, v)) + + # baddbmm: fuse scale into GEMM → one kernel instead of matmul + mul + sim = torch.baddbmm( + torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device), + q, k.transpose(-1, -2), beta=0, alpha=self.scale) + del k + + if exists(mask): + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b i j -> (b h) i j', h=h) + sim.masked_fill_(~(mask > 0.5), max_neg_value) + + with torch.amp.autocast('cuda', enabled=False): + sim = sim.softmax(dim=-1) + + out = torch.bmm(sim, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + + if k_ip is not None and k_as is not None and k_aa is not None: + ## image cross-attention + k_ip, v_ip = map( + lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), + (k_ip, v_ip)) + sim_ip = torch.baddbmm( + torch.empty(q.shape[0], q.shape[1], k_ip.shape[1], dtype=q.dtype, device=q.device), + q, k_ip.transpose(-1, -2), beta=0, alpha=self.scale) + del k_ip + with torch.amp.autocast('cuda', enabled=False): + sim_ip = sim_ip.softmax(dim=-1) + out_ip = torch.bmm(sim_ip, v_ip) + out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h) + + ## agent state cross-attention + k_as, v_as = map( + lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), + (k_as, v_as)) + sim_as = torch.baddbmm( + torch.empty(q.shape[0], q.shape[1], k_as.shape[1], dtype=q.dtype, device=q.device), + q, k_as.transpose(-1, -2), beta=0, alpha=self.scale) + del k_as + with torch.amp.autocast('cuda', enabled=False): + sim_as = sim_as.softmax(dim=-1) + out_as = torch.bmm(sim_as, v_as) + out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h) + + ## agent action cross-attention + k_aa, v_aa = map( + lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), + (k_aa, v_aa)) + sim_aa = torch.baddbmm( + torch.empty(q.shape[0], q.shape[1], k_aa.shape[1], dtype=q.dtype, device=q.device), + q, k_aa.transpose(-1, -2), beta=0, alpha=self.scale) + del k_aa + with torch.amp.autocast('cuda', enabled=False): + sim_aa = sim_aa.softmax(dim=-1) + out_aa = torch.bmm(sim_aa, v_aa) + out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h) + + if out_ip is not None and out_as is not None and out_aa is not None: + if self.cross_attention_scale_learnable: + out = out + \ + self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha_ctx) + 1) + \ + self.agent_state_cross_attention_scale * out_as * (torch.tanh(self.alpha_cas) + 1) + \ + self.agent_action_cross_attention_scale * out_aa * (torch.tanh(self.alpha_caa) + 1) + else: + out = out + \ + self.image_cross_attention_scale * out_ip + \ + self.agent_state_cross_attention_scale * out_as + \ + self.agent_action_cross_attention_scale * out_aa + + return self.to_out(out) + def efficient_forward(self, x, context=None, mask=None): spatial_self_attn = (context is None) k, v, out = None, None, None diff --git a/unitree_g1_pack_camera/case2/output.log b/unitree_g1_pack_camera/case2/output.log index 6ad70b6..a887ea1 100644 --- a/unitree_g1_pack_camera/case2/output.log +++ b/unitree_g1_pack_camera/case2/output.log @@ -1,14 +1,16 @@ -2026-02-08 05:06:45.806187: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. -2026-02-08 05:06:45.809295: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. -2026-02-08 05:06:45.840950: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered -2026-02-08 05:06:45.840981: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered -2026-02-08 05:06:45.842814: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered -2026-02-08 05:06:45.851049: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. -2026-02-08 05:06:45.851316: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. +/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81. + __import__("pkg_resources").declare_namespace(__name__) +2026-02-08 18:28:48.960238: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. +2026-02-08 18:28:48.963331: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. +2026-02-08 18:28:48.995688: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered +2026-02-08 18:28:48.995732: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered +2026-02-08 18:28:48.997547: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered +2026-02-08 18:28:49.005673: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. +2026-02-08 18:28:49.005948: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. -2026-02-08 05:06:47.225477: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT +2026-02-08 18:28:50.009660: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT [rank: 0] Global seed set to 123 -/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead. +/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead. @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08 @@ -18,15 +20,27 @@ INFO:root:Loaded ViT-H-14 model config. DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443 DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0 INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k). -/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. +/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. checkpoint = torch.load(checkpoint_path, map_location=map_location) INFO:root:Loaded ViT-H-14 model config. DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0 INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k). -/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. +/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:198: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. state_dict = torch.load(ckpt, map_location="cpu") >>> model checkpoint loaded. >>> Load pre-trained model ... +>>> Applying precision settings: + - Diffusion dtype: bf16 + - Projector mode: bf16_full + - Encoder mode: bf16_full + - VAE dtype: fp32 + ✓ Diffusion model weights converted to bfloat16 + ✓ Projectors converted to bfloat16 + ✓ Encoders converted to bfloat16 + ✓ VAE kept in fp32 for best quality + ⚠ Found 849 fp32 params, converting to bf16 + ✓ All parameters converted to bfloat16 + ✓ torch.compile: 3 ResBlocks in output_blocks[5, 8, 9] INFO:root:***** Configing Data ***** >>> unitree_z1_stackbox: 1 data samples loaded. >>> unitree_z1_stackbox: data stats loaded. @@ -49,11 +63,11 @@ DEBUG:h5py._conv:Creating converter from 3 to 5 DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13 DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9 DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096 - 0%| | 0/11 [00:00>> Step 0: generating actions ... >>> Step 0: interacting with world model ... @@ -106,7 +120,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin DEBUG:PIL.Image:Importing XbmImagePlugin DEBUG:PIL.Image:Importing XpmImagePlugin DEBUG:PIL.Image:Importing XVThumbImagePlugin - 9%|▉ | 1/11 [01:37<16:14, 97.41s/it] 18%|█▊ | 2/11 [03:14<14:35, 97.22s/it] 27%|██▋ | 3/11 [04:51<12:58, 97.33s/it] 36%|███▋ | 4/11 [06:29<11:22, 97.47s/it] 45%|████▌ | 5/11 [08:07<09:45, 97.57s/it] 55%|█████▍ | 6/11 [09:45<08:07, 97.59s/it] 64%|██████▎ | 7/11 [11:22<06:30, 97.57s/it] 73%|███████▎ | 8/11 [13:00<04:52, 97.54s/it] 82%|████████▏ | 9/11 [14:37<03:14, 97.50s/it] 91%|█████████ | 10/11 [16:14<01:37, 97.32s/it] 100%|██████████| 11/11 [17:51<00:00, 97.19s/it] 100%|██████████| 11/11 [17:51<00:00, 97.39s/it] + 9%|▉ | 1/11 [01:14<12:29, 74.95s/it] 18%|█▊ | 2/11 [02:23<10:40, 71.18s/it] 27%|██▋ | 3/11 [03:32<09:20, 70.05s/it] 36%|███▋ | 4/11 [04:40<08:06, 69.51s/it] 45%|████▌ | 5/11 [05:49<06:55, 69.19s/it] 55%|█████▍ | 6/11 [06:57<05:44, 68.95s/it] 64%|██████▎ | 7/11 [08:06<04:35, 68.79s/it] 73%|███████▎ | 8/11 [09:14<03:26, 68.70s/it] 82%|████████▏ | 9/11 [10:23<02:17, 68.65s/it] 91%|█████████ | 10/11 [11:31<01:08, 68.58s/it] 100%|██████████| 11/11 [12:40<00:00, 68.51s/it] 100%|██████████| 11/11 [12:40<00:00, 69.11s/it] >>>>>>>>>>>>>>>>>>>>>>>> >>> Step 1: generating actions ... >>> Step 1: interacting with world model ... @@ -139,6 +153,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin >>> Step 10: interacting with world model ... >>>>>>>>>>>>>>>>>>>>>>>> -real 19m5.537s -user 16m51.114s -sys 0m52.978s +real 13m47.911s +user 15m20.343s +sys 0m58.417s diff --git a/unitree_z1_dual_arm_cleanup_pencils/case1/output.log b/unitree_z1_dual_arm_cleanup_pencils/case1/output.log index e9417c8..2bb97d7 100644 --- a/unitree_z1_dual_arm_cleanup_pencils/case1/output.log +++ b/unitree_z1_dual_arm_cleanup_pencils/case1/output.log @@ -1,14 +1,16 @@ -2026-02-08 16:49:41.598605: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. -2026-02-08 16:49:41.601687: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. -2026-02-08 16:49:41.632954: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered -2026-02-08 16:49:41.632986: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered -2026-02-08 16:49:41.634849: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered -2026-02-08 16:49:41.643134: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. -2026-02-08 16:49:41.643414: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. +/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81. + __import__("pkg_resources").declare_namespace(__name__) +2026-02-08 18:43:46.463492: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. +2026-02-08 18:43:46.466714: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. +2026-02-08 18:43:46.498994: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered +2026-02-08 18:43:46.499029: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered +2026-02-08 18:43:46.500865: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered +2026-02-08 18:43:46.509069: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. +2026-02-08 18:43:46.509359: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. -2026-02-08 16:49:42.320864: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT +2026-02-08 18:43:47.434136: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT [rank: 0] Global seed set to 123 -/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead. +/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead. @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08 @@ -18,7 +20,7 @@ INFO:root:Loaded ViT-H-14 model config. DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443 DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0 INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k). -/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. +/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. checkpoint = torch.load(checkpoint_path, map_location=map_location) INFO:root:Loaded ViT-H-14 model config. DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0 @@ -61,7 +63,7 @@ DEBUG:h5py._conv:Creating converter from 3 to 5 DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13 DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9 DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096 - 0%| | 0/8 [00:00>> Step 0: generating actions ... >>> Step 0: interacting with world model ... @@ -114,7 +116,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin DEBUG:PIL.Image:Importing XbmImagePlugin DEBUG:PIL.Image:Importing XpmImagePlugin DEBUG:PIL.Image:Importing XVThumbImagePlugin - 12%|█▎ | 1/8 [01:15<08:45, 75.10s/it] 25%|██▌ | 2/8 [02:26<07:17, 72.96s/it] 38%|███▊ | 3/8 [03:38<06:01, 72.27s/it] 50%|█████ | 4/8 [04:49<04:48, 72.00s/it] 62%|██████▎ | 5/8 [06:01<03:35, 71.97s/it] 75%|███████▌ | 6/8 [07:12<02:23, 71.77s/it] 88%|████████▊ | 7/8 [08:24<01:11, 71.56s/it] 100%|██████████| 8/8 [09:35<00:00, 71.59s/it] 100%|██████████| 8/8 [09:35<00:00, 71.96s/it] + 12%|█▎ | 1/8 [01:12<08:27, 72.57s/it] 25%|██▌ | 2/8 [02:21<07:02, 70.44s/it] 38%|███▊ | 3/8 [03:30<05:48, 69.76s/it] 50%|█████ | 4/8 [04:39<04:37, 69.48s/it] 62%|██████▎ | 5/8 [05:48<03:27, 69.31s/it] 75%|███████▌ | 6/8 [06:57<02:18, 69.19s/it] 88%|████████▊ | 7/8 [08:06<01:09, 69.04s/it] 100%|██████████| 8/8 [09:15<00:00, 69.05s/it] 100%|██████████| 8/8 [09:15<00:00, 69.41s/it] >>>>>>>>>>>>>>>>>>>>>>>> >>> Step 1: generating actions ... >>> Step 1: interacting with world model ... @@ -138,6 +140,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin >>> Step 7: interacting with world model ... >>>>>>>>>>>>>>>>>>>>>>>> -real 10m35.511s -user 12m11.689s -sys 0m40.191s +real 10m17.951s +user 11m44.955s +sys 0m40.480s