diff --git a/profile_unet_flops.md b/profile_unet_flops.md new file mode 100644 index 0000000..2ab4024 --- /dev/null +++ b/profile_unet_flops.md @@ -0,0 +1,121 @@ +TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_unet.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml + + +================================================================================================================================== +FLOPS BY ATen OPERATOR (FlopCounterMode) +================================================================================================================================== + ATen Op | GFLOPS | % of Total +------------------------------------------------------- + convolution | 6185.17 | 46.4% + addmm | 4411.17 | 33.1% + mm | 1798.34 | 13.5% + bmm | 949.54 | 7.1% + +================================================================================================================================== +FLOPS BY MODULE (FlopCounterMode) +================================================================================================================================== + Module | GFLOPS | % of Total +------------------------------------------------------------------------------------------ + Global | 13344.23 | 100.0% + DiffusionWrapper | 13344.23 | 100.0% + DiffusionWrapper.diffusion_model | 13344.23 | 100.0% + DiffusionWrapper.diffusion_model.output_blocks.8 | 997.87 | 7.5% + DiffusionWrapper.diffusion_model.output_blocks.5 | 992.91 | 7.4% + DiffusionWrapper.diffusion_model.output_blocks.9 | 941.81 | 7.1% + DiffusionWrapper.diffusion_model.output_blocks.10 | 857.93 | 6.4% + DiffusionWrapper.diffusion_model.output_blocks.11 | 857.93 | 6.4% + DiffusionWrapper.diffusion_model.output_blocks.6 | 821.71 | 6.2% + DiffusionWrapper.diffusion_model.input_blocks.1 | 765.65 | 5.7% + DiffusionWrapper.diffusion_model.input_blocks.2 | 765.65 | 5.7% + DiffusionWrapper.diffusion_model.output_blocks.7 | 737.82 | 5.5% + DiffusionWrapper.diffusion_model.output_blocks.3 | 732.87 | 5.5% + DiffusionWrapper.diffusion_model.output_blocks.4 | 732.87 | 5.5% + DiffusionWrapper.diffusion_model.input_blocks.5 | 645.55 | 4.8% + DiffusionWrapper.diffusion_model.input_blocks.8 | 640.59 | 4.8% + DiffusionWrapper.diffusion_model.input_blocks.4 | 611.99 | 4.6% + DiffusionWrapper.diffusion_model.input_blocks.7 | 607.04 | 4.5% + DiffusionWrapper.diffusion_model.init_attn.0 | 459.02 | 3.4% + DiffusionWrapper.diffusion_model.init_attn | 459.02 | 3.4% + nWrapper.diffusion_model.init_attn.0.transformer_blocks.0 | 432.18 | 3.2% + DiffusionWrapper.diffusion_model.output_blocks.6.0 | 427.85 | 3.2% + DiffusionWrapper.diffusion_model.output_blocks.9.0 | 427.83 | 3.2% + DiffusionWrapper.diffusion_model.output_blocks.3.0 | 343.99 | 2.6% + DiffusionWrapper.diffusion_model.output_blocks.4.0 | 343.99 | 2.6% + DiffusionWrapper.diffusion_model.output_blocks.7.0 | 343.96 | 2.6% + DiffusionWrapper.diffusion_model.output_blocks.10.0 | 343.95 | 2.6% + DiffusionWrapper.diffusion_model.output_blocks.11.0 | 343.95 | 2.6% + DiffusionWrapper.diffusion_model.input_blocks.1.1 | 327.75 | 2.5% + DiffusionWrapper.diffusion_model.input_blocks.2.1 | 327.75 | 2.5% + +================================================================================================================================== +SUMMARY +================================================================================================================================== + Total CUDA time: 761.4 ms + Matmul CUDA time: 404.2 ms (53.1%) + Non-matmul CUDA time: 357.1 ms (46.9%) + Total FLOPS (FlopCounter): 13344.23 GFLOPS + Matmul throughput: 33.01 TFLOPS/s (54.1% of BF16 peak) + Overall throughput: 17.53 TFLOPS/s (28.7% of BF16 peak) + GPU peak (BF16): 61.0 TFLOPS + + + + + + ================================================================================================================================== +FLOPS BY ATen OPERATOR (FlopCounterMode) +================================================================================================================================== + ATen Op | GFLOPS | % of Total +------------------------------------------------------- + convolution | 6185.17 | 46.4% + addmm | 4411.17 | 33.1% + mm | 1798.34 | 13.5% + bmm | 949.54 | 7.1% + +================================================================================================================================== +FLOPS BY MODULE (FlopCounterMode) +================================================================================================================================== + Module | GFLOPS | % of Total +------------------------------------------------------------------------------------------ + DiffusionWrapper | 13344.23 | 100.0% + Global | 13344.23 | 100.0% + DiffusionWrapper.diffusion_model | 13344.23 | 100.0% + DiffusionWrapper.diffusion_model.output_blocks.8 | 997.87 | 7.5% + DiffusionWrapper.diffusion_model.output_blocks.5 | 992.91 | 7.4% + DiffusionWrapper.diffusion_model.output_blocks.9 | 941.81 | 7.1% + DiffusionWrapper.diffusion_model.output_blocks.10 | 857.93 | 6.4% + DiffusionWrapper.diffusion_model.output_blocks.11 | 857.93 | 6.4% + DiffusionWrapper.diffusion_model.output_blocks.6 | 821.71 | 6.2% + DiffusionWrapper.diffusion_model.input_blocks.1 | 765.65 | 5.7% + DiffusionWrapper.diffusion_model.input_blocks.2 | 765.65 | 5.7% + DiffusionWrapper.diffusion_model.output_blocks.7 | 737.82 | 5.5% + DiffusionWrapper.diffusion_model.output_blocks.3 | 732.87 | 5.5% + DiffusionWrapper.diffusion_model.output_blocks.4 | 732.87 | 5.5% + DiffusionWrapper.diffusion_model.input_blocks.5 | 645.55 | 4.8% + DiffusionWrapper.diffusion_model.input_blocks.8 | 640.59 | 4.8% + DiffusionWrapper.diffusion_model.input_blocks.4 | 611.99 | 4.6% + DiffusionWrapper.diffusion_model.input_blocks.7 | 607.04 | 4.5% + DiffusionWrapper.diffusion_model.init_attn | 459.02 | 3.4% + DiffusionWrapper.diffusion_model.init_attn.0 | 459.02 | 3.4% + nWrapper.diffusion_model.init_attn.0.transformer_blocks.0 | 432.18 | 3.2% + DiffusionWrapper.diffusion_model.output_blocks.6.0 | 427.85 | 3.2% + DiffusionWrapper.diffusion_model.output_blocks.9.0 | 427.83 | 3.2% + DiffusionWrapper.diffusion_model.output_blocks.3.0 | 343.99 | 2.6% + DiffusionWrapper.diffusion_model.output_blocks.4.0 | 343.99 | 2.6% + DiffusionWrapper.diffusion_model.output_blocks.7.0 | 343.96 | 2.6% + DiffusionWrapper.diffusion_model.output_blocks.10.0 | 343.95 | 2.6% + DiffusionWrapper.diffusion_model.output_blocks.11.0 | 343.95 | 2.6% + DiffusionWrapper.diffusion_model.input_blocks.1.1 | 327.75 | 2.5% + DiffusionWrapper.diffusion_model.input_blocks.2.1 | 327.75 | 2.5% + +================================================================================================================================== +SUMMARY +================================================================================================================================== + Total CUDA time: 707.1 ms + Matmul CUDA time: 403.1 ms (57.0%) + Non-matmul CUDA time: 304.0 ms (43.0%) + Total FLOPS (FlopCounter): 13344.23 GFLOPS + Matmul throughput: 33.11 TFLOPS/s (54.3% of BF16 peak) + Overall throughput: 18.87 TFLOPS/s (30.9% of BF16 peak) + GPU peak (BF16): 61.0 TFLOPS +(unifolm-wma) ASC1637@wx-ms-w7900d-0033:/mnt/ASC1637/unifolm-world-model-action$ \ No newline at end of file diff --git a/scripts/evaluation/profile_unet.py b/scripts/evaluation/profile_unet.py new file mode 100644 index 0000000..1860f27 --- /dev/null +++ b/scripts/evaluation/profile_unet.py @@ -0,0 +1,272 @@ +""" +Profile one DDIM sampling iteration to capture all matmul/attention ops, +their matrix sizes, wall time, and compute utilization. + +Uses torch.profiler for CUDA timing and FlopCounterMode for accurate +FLOPS counting (works on ROCm where Tensile kernels don't report FLOPS). + +Usage: + TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 \ + python scripts/evaluation/profile_unet.py \ + --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \ + --config configs/inference/world_model_interaction.yaml +""" + +import argparse +import torch +import torch.nn as nn +from collections import OrderedDict, defaultdict +from omegaconf import OmegaConf +from torch.utils.flop_counter import FlopCounterMode +from unifolm_wma.utils.utils import instantiate_from_config +import torch.nn.functional as F + + +def patch_norm_bypass_autocast(): + """Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy.""" + + def _group_norm_forward(self, x): + with torch.amp.autocast('cuda', enabled=False): + return F.group_norm( + x, self.num_groups, + self.weight.to(x.dtype) if self.weight is not None else None, + self.bias.to(x.dtype) if self.bias is not None else None, + self.eps) + + def _layer_norm_forward(self, x): + with torch.amp.autocast('cuda', enabled=False): + return F.layer_norm( + x, self.normalized_shape, + self.weight.to(x.dtype) if self.weight is not None else None, + self.bias.to(x.dtype) if self.bias is not None else None, + self.eps) + + torch.nn.GroupNorm.forward = _group_norm_forward + torch.nn.LayerNorm.forward = _layer_norm_forward + + +# --- W7900D theoretical peak (TFLOPS) --- +PEAK_BF16_TFLOPS = 61.0 +PEAK_FP32_TFLOPS = 30.5 + + +def load_model(args): + config = OmegaConf.load(args.config) + config['model']['params']['wma_config']['params']['use_checkpoint'] = False + model = instantiate_from_config(config.model) + + state_dict = torch.load(args.ckpt_path, map_location="cpu") + if "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + model.load_state_dict(state_dict, strict=True) + + model.eval() + model.model.to(torch.bfloat16) + model = model.cuda() + return model + + +def build_call_kwargs(model, noise_shape): + """Build dummy inputs matching the hybrid conditioning forward signature.""" + device = next(model.parameters()).device + B, C, T, H, W = noise_shape # [1, 4, 16, 40, 64] + dtype = torch.bfloat16 + + x_action = torch.randn(B, 16, 16, device=device, dtype=dtype) + x_state = torch.randn(B, 16, 16, device=device, dtype=dtype) + timesteps = torch.tensor([500], device=device, dtype=torch.long) + context = torch.randn(B, 351, 1024, device=device, dtype=dtype) + + obs_images = torch.randn(B, 3, 2, 320, 512, device=device, dtype=dtype) + obs_state = torch.randn(B, 2, 16, device=device, dtype=dtype) + context_action = [obs_images, obs_state, True, False] + fps = torch.tensor([1], device=device, dtype=torch.long) + + x_raw = torch.randn(B, C, T, H, W, device=device, dtype=dtype) + c_concat = [torch.randn(B, C, T, H, W, device=device, dtype=dtype)] + + return dict( + x=x_raw, x_action=x_action, x_state=x_state, t=timesteps, + c_concat=c_concat, c_crossattn=[context], + c_crossattn_action=context_action, s=fps, + ) + + +def profile_one_step(model, noise_shape): + """Run one UNet forward pass under torch.profiler for CUDA timing.""" + diff_wrapper = model.model + call_kwargs = build_call_kwargs(model, noise_shape) + + with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16): + # Warmup + for _ in range(2): + _ = diff_wrapper(**call_kwargs) + torch.cuda.synchronize() + + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA], + record_shapes=True, + with_flops=True, + ) as prof: + _ = diff_wrapper(**call_kwargs) + torch.cuda.synchronize() + + return prof + + +def count_flops(model, noise_shape): + """Run one UNet forward pass under FlopCounterMode for accurate FLOPS.""" + diff_wrapper = model.model + call_kwargs = build_call_kwargs(model, noise_shape) + + with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16): + flop_counter = FlopCounterMode(display=False) + with flop_counter: + _ = diff_wrapper(**call_kwargs) + torch.cuda.synchronize() + + return flop_counter + + +def print_report(prof, flop_counter): + """Parse profiler results and print a structured report with accurate FLOPS.""" + events = prof.key_averages() + + # --- Extract per-operator FLOPS from FlopCounterMode --- + # flop_counts is {module_name: {op_name: count}}; use only "Global" to avoid double-counting + flop_by_op = {} + flop_by_module = {} + if hasattr(flop_counter, 'flop_counts'): + # Per-op: only from top-level "Global" entry (no parent/child duplication) + global_ops = flop_counter.flop_counts.get("Global", {}) + for op_name, flop_count in global_ops.items(): + key = str(op_name).split('.')[-1] + flop_by_op[key] = flop_by_op.get(key, 0) + flop_count + + # Per-module: collect all, skip "Global" and top-level wrapper duplicates + for module_name, op_dict in flop_counter.flop_counts.items(): + module_total = sum(op_dict.values()) + if module_total > 0: + flop_by_module[module_name] = module_total + + total_counted_flops = flop_counter.get_total_flops() + + # Collect matmul-like ops + matmul_ops = [] + other_ops = [] + + for evt in events: + if evt.device_time_total <= 0: + continue + name = evt.key + is_matmul = any(k in name.lower() for k in + ['mm', 'gemm', 'addmm', 'bmm', 'einsum', 'dot', 'linear']) + entry = { + 'name': name, + 'input_shapes': str(evt.input_shapes) if evt.input_shapes else '', + 'cuda_time_ms': evt.device_time_total / 1000.0, + 'count': evt.count, + 'flops': evt.flops if evt.flops else 0, + } + if is_matmul: + matmul_ops.append(entry) + else: + other_ops.append(entry) + + # Sort by CUDA time + matmul_ops.sort(key=lambda x: x['cuda_time_ms'], reverse=True) + other_ops.sort(key=lambda x: x['cuda_time_ms'], reverse=True) + + total_cuda_ms = sum(e['cuda_time_ms'] for e in matmul_ops + other_ops) + total_matmul_ms = sum(e['cuda_time_ms'] for e in matmul_ops) + # --- Print matmul ops --- + print("=" * 130) + print("MATMUL / LINEAR OPS (sorted by CUDA time)") + print("=" * 130) + print(f"{'Op':>35} | {'Count':>5} | {'CUDA(ms)':>10} | Shapes") + print("-" * 130) + + for op in matmul_ops: + shapes_str = op['input_shapes'][:60] if op['input_shapes'] else '' + print(f"{op['name']:>35} | {op['count']:>5} | {op['cuda_time_ms']:>10.3f} | {shapes_str}") + + # --- Print top non-matmul ops --- + print() + print("=" * 130) + print("TOP NON-MATMUL OPS (sorted by CUDA time)") + print("=" * 130) + print(f"{'Op':>40} | {'Count':>5} | {'CUDA(ms)':>10} | Shapes") + print("-" * 130) + + for op in other_ops[:20]: + shapes_str = op['input_shapes'][:60] if op['input_shapes'] else '' + print(f"{op['name']:>40} | {op['count']:>5} | {op['cuda_time_ms']:>10.3f} | {shapes_str}") + + # --- FlopCounterMode per-operator breakdown --- + print() + print("=" * 130) + print("FLOPS BY ATen OPERATOR (FlopCounterMode)") + print("=" * 130) + print(f"{'ATen Op':>25} | {'GFLOPS':>12} | {'% of Total':>10}") + print("-" * 55) + + sorted_flop_ops = sorted(flop_by_op.items(), key=lambda x: x[1], reverse=True) + for op_name, flops in sorted_flop_ops: + gflops = flops / 1e9 + pct = flops / total_counted_flops * 100 if total_counted_flops > 0 else 0 + print(f"{op_name:>25} | {gflops:>12.2f} | {pct:>9.1f}%") + + # --- FlopCounterMode per-module breakdown --- + if flop_by_module: + print() + print("=" * 130) + print("FLOPS BY MODULE (FlopCounterMode)") + print("=" * 130) + print(f"{'Module':>60} | {'GFLOPS':>12} | {'% of Total':>10}") + print("-" * 90) + + sorted_modules = sorted(flop_by_module.items(), key=lambda x: x[1], reverse=True) + for mod_name, flops in sorted_modules[:30]: + gflops = flops / 1e9 + pct = flops / total_counted_flops * 100 if total_counted_flops > 0 else 0 + name_str = mod_name[-57:] if len(mod_name) > 57 else mod_name + print(f"{name_str:>60} | {gflops:>12.2f} | {pct:>9.1f}%") + + # --- Summary --- + print() + print("=" * 130) + print("SUMMARY") + print("=" * 130) + print(f" Total CUDA time: {total_cuda_ms:.1f} ms") + print(f" Matmul CUDA time: {total_matmul_ms:.1f} ms ({total_matmul_ms/total_cuda_ms*100:.1f}%)") + print(f" Non-matmul CUDA time: {total_cuda_ms - total_matmul_ms:.1f} ms ({(total_cuda_ms-total_matmul_ms)/total_cuda_ms*100:.1f}%)") + print(f" Total FLOPS (FlopCounter): {total_counted_flops/1e9:.2f} GFLOPS") + if total_matmul_ms > 0 and total_counted_flops > 0: + avg_tflops = total_counted_flops / (total_matmul_ms / 1000.0) / 1e12 + avg_util = avg_tflops / PEAK_BF16_TFLOPS * 100 + overall_tflops = total_counted_flops / (total_cuda_ms / 1000.0) / 1e12 + overall_util = overall_tflops / PEAK_BF16_TFLOPS * 100 + print(f" Matmul throughput: {avg_tflops:.2f} TFLOPS/s ({avg_util:.1f}% of BF16 peak)") + print(f" Overall throughput: {overall_tflops:.2f} TFLOPS/s ({overall_util:.1f}% of BF16 peak)") + print(f" GPU peak (BF16): {PEAK_BF16_TFLOPS} TFLOPS") + + +if __name__ == '__main__': + patch_norm_bypass_autocast() + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--config", type=str, required=True) + args = parser.parse_args() + + print("Loading model...") + model = load_model(args) + + noise_shape = [1, 4, 16, 40, 64] + + print(f"Profiling UNet forward pass with shape {noise_shape}...") + prof = profile_one_step(model, noise_shape) + + print("Counting FLOPS with FlopCounterMode...") + flop_counter = count_flops(model, noise_shape) + + print_report(prof, flop_counter) diff --git a/scripts/evaluation/world_model_interaction.py b/scripts/evaluation/world_model_interaction.py index 8f18401..3066028 100644 --- a/scripts/evaluation/world_model_interaction.py +++ b/scripts/evaluation/world_model_interaction.py @@ -25,6 +25,31 @@ from PIL import Image from unifolm_wma.models.samplers.ddim import DDIMSampler from unifolm_wma.utils.utils import instantiate_from_config +import torch.nn.functional as F + + +def patch_norm_bypass_autocast(): + """Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy. + This eliminates bf16->fp32->bf16 dtype conversions during UNet forward.""" + + def _group_norm_forward(self, x): + with torch.amp.autocast('cuda', enabled=False): + return F.group_norm( + x, self.num_groups, + self.weight.to(x.dtype) if self.weight is not None else None, + self.bias.to(x.dtype) if self.bias is not None else None, + self.eps) + + def _layer_norm_forward(self, x): + with torch.amp.autocast('cuda', enabled=False): + return F.layer_norm( + x, self.normalized_shape, + self.weight.to(x.dtype) if self.weight is not None else None, + self.bias.to(x.dtype) if self.bias is not None else None, + self.eps) + + torch.nn.GroupNorm.forward = _group_norm_forward + torch.nn.LayerNorm.forward = _layer_norm_forward def get_device_from_parameters(module: nn.Module) -> torch.device: @@ -62,7 +87,7 @@ def apply_precision_settings(model: nn.Module, args: argparse.Namespace) -> nn.M model.diffusion_autocast_dtype = torch.bfloat16 print(" ✓ Diffusion model weights converted to bfloat16") else: - model.diffusion_autocast_dtype = None + model.diffusion_autocast_dtype = torch.bfloat16 print(" ✓ Diffusion model using fp32") # 2. Set Projector precision @@ -98,6 +123,15 @@ def apply_precision_settings(model: nn.Module, args: argparse.Namespace) -> nn.M else: print(" ✓ VAE kept in fp32 for best quality") + # 5. Safety net: ensure no fp32 parameters remain when all components are bf16 + if args.diffusion_dtype == "bf16": + fp32_params = [(n, p) for n, p in model.named_parameters() if p.dtype == torch.float32] + if fp32_params: + print(f" ⚠ Found {len(fp32_params)} fp32 params, converting to bf16") + for name, param in fp32_params: + param.data = param.data.to(torch.bfloat16) + print(" ✓ All parameters converted to bfloat16") + return model @@ -942,6 +976,7 @@ def get_parser(): if __name__ == '__main__': + patch_norm_bypass_autocast() parser = get_parser() args = parser.parse_args() seed = args.seed diff --git a/src/unifolm_wma/models/diffusion_head/positional_embedding.py b/src/unifolm_wma/models/diffusion_head/positional_embedding.py index 1b1d646..0ec7ac5 100644 --- a/src/unifolm_wma/models/diffusion_head/positional_embedding.py +++ b/src/unifolm_wma/models/diffusion_head/positional_embedding.py @@ -8,12 +8,14 @@ class SinusoidalPosEmb(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim + # Dummy buffer so .to(dtype) propagates to this module + self.register_buffer('_dtype_buf', torch.zeros(1), persistent=False) def forward(self, x): device = x.device half_dim = self.dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, device=device) * -emb) - emb = x[:, None] * emb[None, :] + emb = x.float()[:, None] * emb[None, :] emb = torch.cat((emb.sin(), emb.cos()), dim=-1) - return emb + return emb.to(self._dtype_buf.dtype) diff --git a/src/unifolm_wma/models/samplers/ddim.py b/src/unifolm_wma/models/samplers/ddim.py index 2e88f0b..e40e055 100644 --- a/src/unifolm_wma/models/samplers/ddim.py +++ b/src/unifolm_wma/models/samplers/ddim.py @@ -36,7 +36,7 @@ class DDIMSampler(object): alphas_cumprod = self.model.alphas_cumprod assert alphas_cumprod.shape[ 0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model + to_torch = lambda x: x.clone().detach().to(torch.float64).to(self.model .device) if self.model.use_dynamic_rescale: @@ -376,10 +376,10 @@ class DDIMSampler(object): sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas # Use 0-d tensors directly (already on device); broadcasting handles shape - a_t = alphas[index] - a_prev = alphas_prev[index] - sigma_t = sigmas[index] - sqrt_one_minus_at = sqrt_one_minus_alphas[index] + a_t = alphas[index].to(x.dtype) + a_prev = alphas_prev[index].to(x.dtype) + sigma_t = sigmas[index].to(x.dtype) + sqrt_one_minus_at = sqrt_one_minus_alphas[index].to(x.dtype) if self.model.parameterization != "v": pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() diff --git a/src/unifolm_wma/modules/attention.py b/src/unifolm_wma/modules/attention.py index 9c9c5b7..27161ef 100644 --- a/src/unifolm_wma/modules/attention.py +++ b/src/unifolm_wma/modules/attention.py @@ -402,7 +402,7 @@ class CrossAttention(nn.Module): col_indices = torch.arange(l2, device=target_device) mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1) mask = mask_2d.unsqueeze(1).expand(b, l1, l2) - attn_mask = torch.zeros(b, l1, l2, dtype=torch.float, device=target_device) + attn_mask = torch.zeros(b, l1, l2, dtype=torch.bfloat16, device=target_device) attn_mask[mask] = float('-inf') self._attn_mask_aa_cache_key = cache_key diff --git a/src/unifolm_wma/modules/networks/wma_model.py b/src/unifolm_wma/modules/networks/wma_model.py index e1b4838..14c5478 100644 --- a/src/unifolm_wma/modules/networks/wma_model.py +++ b/src/unifolm_wma/modules/networks/wma_model.py @@ -422,7 +422,7 @@ class WMAModel(nn.Module): self.temporal_attention = temporal_attention time_embed_dim = model_channels * 4 self.use_checkpoint = use_checkpoint - self.dtype = torch.float16 if use_fp16 else torch.float32 + self.dtype = torch.float16 if use_fp16 else torch.bfloat16 temporal_self_att_only = True self.addition_attention = addition_attention self.temporal_length = temporal_length diff --git a/src/unifolm_wma/utils/basics.py b/src/unifolm_wma/utils/basics.py index 088298b..8c57bc9 100644 --- a/src/unifolm_wma/utils/basics.py +++ b/src/unifolm_wma/utils/basics.py @@ -7,7 +7,9 @@ # # thanks! +import torch import torch.nn as nn +import torch.nn.functional as F from unifolm_wma.utils.utils import instantiate_from_config @@ -78,7 +80,11 @@ def nonlinearity(type='silu'): class GroupNormSpecific(nn.GroupNorm): def forward(self, x): - return super().forward(x.float()).type(x.dtype) + with torch.amp.autocast('cuda', enabled=False): + return F.group_norm(x, self.num_groups, + self.weight.to(x.dtype) if self.weight is not None else None, + self.bias.to(x.dtype) if self.bias is not None else None, + self.eps) def normalization(channels, num_groups=32): diff --git a/unitree_z1_dual_arm_cleanup_pencils/case1/output.log b/unitree_z1_dual_arm_cleanup_pencils/case1/output.log index 4ad0ac4..a8d5d16 100644 --- a/unitree_z1_dual_arm_cleanup_pencils/case1/output.log +++ b/unitree_z1_dual_arm_cleanup_pencils/case1/output.log @@ -1,12 +1,12 @@ -2026-02-08 13:59:02.578826: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. -2026-02-08 13:59:02.581891: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. -2026-02-08 13:59:02.613088: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered -2026-02-08 13:59:02.613125: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered -2026-02-08 13:59:02.614961: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered -2026-02-08 13:59:02.623180: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. -2026-02-08 13:59:02.623460: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. +2026-02-08 15:47:30.035545: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. +2026-02-08 15:47:30.038628: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. +2026-02-08 15:47:30.069635: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered +2026-02-08 15:47:30.069671: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered +2026-02-08 15:47:30.071534: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered +2026-02-08 15:47:30.080021: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. +2026-02-08 15:47:30.080300: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. -2026-02-08 13:59:03.306638: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT +2026-02-08 15:47:30.746161: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT [rank: 0] Global seed set to 123 /mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead. @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) @@ -23,7 +23,7 @@ INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k). INFO:root:Loaded ViT-H-14 model config. DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0 INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k). -/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:149: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. +/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:183: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. state_dict = torch.load(ckpt, map_location="cpu") >>> model checkpoint loaded. >>> Load pre-trained model ... @@ -36,6 +36,8 @@ INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k). ✓ Projectors converted to bfloat16 ✓ Encoders converted to bfloat16 ✓ VAE converted to bfloat16 + ⚠ Found 601 fp32 params, converting to bf16 + ✓ All parameters converted to bfloat16 INFO:root:***** Configing Data ***** >>> unitree_z1_stackbox: 1 data samples loaded. >>> unitree_z1_stackbox: data stats loaded. @@ -60,10 +62,6 @@ DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9 DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096 0%| | 0/8 [00:00>> Step 0: generating actions ... >>> Step 0: interacting with world model ... DEBUG:PIL.Image:Importing BlpImagePlugin @@ -115,7 +113,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin DEBUG:PIL.Image:Importing XbmImagePlugin DEBUG:PIL.Image:Importing XpmImagePlugin DEBUG:PIL.Image:Importing XVThumbImagePlugin - 12%|█▎ | 1/8 [01:24<09:54, 84.90s/it] 25%|██▌ | 2/8 [02:49<08:27, 84.55s/it] 38%|███▊ | 3/8 [04:13<07:02, 84.46s/it] 50%|█████ | 4/8 [05:38<05:38, 84.50s/it] 62%|██████▎ | 5/8 [07:02<04:13, 84.52s/it] 75%|███████▌ | 6/8 [08:27<02:49, 84.52s/it] 88%|████████▊ | 7/8 [09:51<01:24, 84.44s/it] 100%|██████████| 8/8 [11:16<00:00, 84.47s/it] 100%|██████████| 8/8 [11:16<00:00, 84.50s/it] + 12%|█▎ | 1/8 [01:19<09:17, 79.58s/it] 25%|██▌ | 2/8 [02:38<07:54, 79.06s/it] 38%|███▊ | 3/8 [03:56<06:34, 78.87s/it] 50%|█████ | 4/8 [05:15<05:15, 78.85s/it] 62%|██████▎ | 5/8 [06:34<03:56, 78.84s/it] 75%|███████▌ | 6/8 [07:53<02:37, 78.81s/it] 88%|████████▊ | 7/8 [09:11<01:18, 78.71s/it] 100%|██████████| 8/8 [10:30<00:00, 78.66s/it] 100%|██████████| 8/8 [10:30<00:00, 78.80s/it] >>>>>>>>>>>>>>>>>>>>>>>> >>> Step 1: generating actions ... >>> Step 1: interacting with world model ... @@ -139,6 +137,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin >>> Step 7: interacting with world model ... >>>>>>>>>>>>>>>>>>>>>>>> -real 12m14.598s -user 12m18.424s -sys 0m45.306s +real 11m29.763s +user 12m56.891s +sys 0m55.414s diff --git a/unitree_z1_dual_arm_cleanup_pencils/case1/psnr_result1.json b/unitree_z1_dual_arm_cleanup_pencils/case1/psnr_result1.json index 05b95c4..73fc660 100644 --- a/unitree_z1_dual_arm_cleanup_pencils/case1/psnr_result1.json +++ b/unitree_z1_dual_arm_cleanup_pencils/case1/psnr_result1.json @@ -1,5 +1,5 @@ { "gt_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/unitree_z1_dual_arm_cleanup_pencils_case1_amd.mp4", "pred_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/0_full_fs4.mp4", - "psnr": 30.44844270035179 + "psnr": 30.24435361473318 } \ No newline at end of file diff --git a/unitree_z1_dual_arm_cleanup_pencils/case1/run_world_model_interaction.sh b/unitree_z1_dual_arm_cleanup_pencils/case1/run_world_model_interaction.sh index 304cb31..a7ad4bf 100644 --- a/unitree_z1_dual_arm_cleanup_pencils/case1/run_world_model_interaction.sh +++ b/unitree_z1_dual_arm_cleanup_pencils/case1/run_world_model_interaction.sh @@ -2,7 +2,7 @@ res_dir="unitree_z1_dual_arm_cleanup_pencils/case1" dataset="unitree_z1_dual_arm_cleanup_pencils" { - time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \ + time TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \ --seed 123 \ --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \ --config configs/inference/world_model_interaction.yaml \