1. einsum('b i d, b j d -> b i j') → torch.bmm(q, k.transpose(-1,-2)) — 直接映射 rocBLAS batched GEMM

2. baddbmm 把 scale 融合进 GEMM,少一次 kernel launch
3. 第二个 einsum 同理换torch.bm
每一轮加速1到两秒
This commit is contained in:
2026-02-08 18:54:48 +00:00
parent 7338cc384a
commit a2cd34dd51
5 changed files with 994 additions and 37 deletions

View File

@@ -118,4 +118,100 @@ SUMMARY
Matmul throughput: 33.11 TFLOPS/s (54.3% of BF16 peak)
Overall throughput: 18.87 TFLOPS/s (30.9% of BF16 peak)
GPU peak (BF16): 61.0 TFLOPS
(unifolm-wma) ASC1637@wx-ms-w7900d-0033:/mnt/ASC1637/unifolm-world-model-action$
(unifolm-wma) ASC1637@wx-ms-w7900d-0033:/mnt/ASC1637/unifolm-world-model-action$
========================================================================
TABLE 1: STAGE TIMING
========================================================================
Stage Mean(ms) Std %
------------------------------------------------------------------------
1_Image_Embedding 29.5 0.16 0.1%
2_VAE_Encode 51.3 0.06 0.1%
3_Text_Conditioning 14.7 0.18 0.0%
4_Projectors 0.2 0.03 0.0%
5_DDIM_Loop 33392.5 3.21 97.3%
6_VAE_Decode 808.4 1.00 2.4%
7_Post_Process 15.8 0.56 0.0%
------------------------------------------------------------------------
TOTAL 34312.4
================================================================================
TABLE 2: UNET SUB-MODULE BREAKDOWN
================================================================================
Module Type Total(ms) Count Per-call %
--------------------------------------------------------------------------------
ResBlock 10256.3 1100 9.32 23.2%
SpatialTransformer 9228.2 800 11.54 20.9%
CrossAttention 8105.8 3300 2.46 18.3%
ConditionalUnet1D 6409.5 100 64.10 14.5%
TemporalTransformer 5847.0 850 6.88 13.2%
FeedForward 4338.1 1650 2.63 9.8%
UNet.out 73.8 50 1.48 0.2%
--------------------------------------------------------------------------------
TOTAL (hooked) 44258.7
==========================================================================================
TABLE 2b: PER-UNET-BLOCK TIMING (coarse modules, per DDIM loop)
==========================================================================================
Block Total(ms) % Breakdown
------------------------------------------------------------------------------------------
input_blocks.1 3376.2 7.6% SpatialTransformer=1101, CrossAttention=990, ResBlock=543, TemporalTransformer=454, FeedForward=288
input_blocks.2 3374.0 7.6% SpatialTransformer=1100, CrossAttention=991, ResBlock=540, TemporalTransformer=455, FeedForward=288
input_blocks.4 1592.4 3.6% SpatialTransformer=394, ResBlock=374, CrossAttention=303, TemporalTransformer=272, FeedForward=249
input_blocks.5 1642.5 3.7% ResBlock=425, SpatialTransformer=397, CrossAttention=303, TemporalTransformer=271, FeedForward=247
input_blocks.7 1469.0 3.3% ResBlock=416, SpatialTransformer=324, FeedForward=251, CrossAttention=240, TemporalTransformer=237
input_blocks.8 1543.7 3.5% ResBlock=491, SpatialTransformer=325, FeedForward=250, CrossAttention=240, TemporalTransformer=238
input_blocks.10 217.5 0.5% ResBlock=218
input_blocks.11 216.8 0.5% ResBlock=217
middle_block 848.9 1.9% ResBlock=434, SpatialTransformer=151, CrossAttention=134, TemporalTransformer=69, FeedForward=61
output_blocks.0 303.2 0.7% ResBlock=303
output_blocks.1 303.1 0.7% ResBlock=303
output_blocks.2 302.8 0.7% ResBlock=303
output_blocks.3 1734.8 3.9% ResBlock=687, SpatialTransformer=322, FeedForward=249, CrossAttention=239, TemporalTransformer=237
output_blocks.4 1739.8 3.9% ResBlock=688, SpatialTransformer=323, FeedForward=251, CrossAttention=239, TemporalTransformer=238
output_blocks.5 1622.3 3.7% ResBlock=570, SpatialTransformer=324, FeedForward=251, CrossAttention=239, TemporalTransformer=238
output_blocks.6 1881.0 4.3% ResBlock=664, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=250
output_blocks.7 1768.0 4.0% ResBlock=554, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=249
output_blocks.8 1688.7 3.8% ResBlock=474, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=249
output_blocks.9 3558.6 8.0% SpatialTransformer=1096, CrossAttention=992, ResBlock=727, TemporalTransformer=454, FeedForward=290
output_blocks.10 3492.8 7.9% SpatialTransformer=1096, CrossAttention=992, ResBlock=662, TemporalTransformer=454, FeedForward=289
output_blocks.11 3493.3 7.9% SpatialTransformer=1096, CrossAttention=992, ResBlock=662, TemporalTransformer=454, FeedForward=289
out 73.8 0.2% UNet.out=74
action_unet 3212.0 7.3% ConditionalUnet1D=3212
state_unet 3197.6 7.2% ConditionalUnet1D=3198
other 1606.2 3.6% TemporalTransformer=960, FeedForward=337, CrossAttention=309
------------------------------------------------------------------------------------------
TOTAL 44258.7
======================================================================
TABLE 2c: ATTENTION vs FEEDFORWARD (deep hooks)
======================================================================
Component Total(ms) %
----------------------------------------------------------------------
CrossAttention 8105.8 65.1%
FeedForward 4338.1 34.9%
----------------------------------------------------------------------
TOTAL (attn+ff) 12443.9
==================================================
TABLE 3: MEMORY SUMMARY
==================================================
Initial allocated: 11.82 GB
Peak allocated: 14.43 GB
Delta (pipeline): 2.61 GB
============================================================
TABLE 4: THROUGHPUT
============================================================
Total pipeline latency: 34312.4 ms
DDIM loop latency: 33392.5 ms
DDIM steps: 50
CFG scale: 1.0 (1x UNet/step)
UNet forward calls: 50
Per DDIM step: 667.9 ms
Per UNet forward: 667.9 ms
VAE encode bandwidth: 0.1 GB/s (peak HBM: 864.0 GB/s)
VAE decode bandwidth: 0.0 GB/s (peak HBM: 864.0 GB/s)
GPU BF16 peak: 61.0 TFLOPS
Done.

View File

@@ -0,0 +1,733 @@
"""
Profile the full inference pipeline of the world model, covering all 7 stages:
1. Image Embedding
2. VAE Encode
3. Text Conditioning
4. State/Action Projectors
5. DDIM Loop
6. VAE Decode
7. Post-process
Reports stage-level timing, UNet sub-module breakdown, memory summary,
and throughput analysis.
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_pipeline.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --ddim_steps 50 --cfg_scale 1.0 --n_runs 3 --deep
Usage:
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_pipeline.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --ddim_steps 50 --cfg_scale 1.0 --n_runs 3
"""
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint # must be loaded before unifolm_wma.utils.common
from contextlib import nullcontext, contextmanager
from collections import defaultdict
from omegaconf import OmegaConf
from einops import rearrange, repeat
from unifolm_wma.utils.utils import instantiate_from_config
from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.modules.attention import (
SpatialTransformer, TemporalTransformer,
BasicTransformerBlock, CrossAttention, FeedForward,
)
from unifolm_wma.modules.networks.wma_model import ResBlock
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
# --- W7900D theoretical peak ---
PEAK_BF16_TFLOPS = 61.0
MEM_BW_GBS = 864.0
# ---------------------------------------------------------------------------
# Utility: patch norms to bypass autocast fp32 promotion
# ---------------------------------------------------------------------------
def patch_norm_bypass_autocast():
"""Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy."""
def _group_norm_forward(self, x):
with torch.amp.autocast('cuda', enabled=False):
return F.group_norm(
x, self.num_groups,
self.weight.to(x.dtype) if self.weight is not None else None,
self.bias.to(x.dtype) if self.bias is not None else None,
self.eps)
def _layer_norm_forward(self, x):
with torch.amp.autocast('cuda', enabled=False):
return F.layer_norm(
x, self.normalized_shape,
self.weight.to(x.dtype) if self.weight is not None else None,
self.bias.to(x.dtype) if self.bias is not None else None,
self.eps)
torch.nn.GroupNorm.forward = _group_norm_forward
torch.nn.LayerNorm.forward = _layer_norm_forward
# ---------------------------------------------------------------------------
# Utility: torch.compile hot ResBlocks
# ---------------------------------------------------------------------------
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
"""Compile ResBlock._forward in the hottest output_blocks for operator fusion."""
unet = model.model.diffusion_model
compiled = 0
for idx in hot_indices:
block = unet.output_blocks[idx]
for layer in block:
if isinstance(layer, ResBlock):
layer._forward = torch.compile(layer._forward, mode="default")
compiled += 1
print(f" torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
def load_model(args):
config = OmegaConf.load(args.config)
config['model']['params']['wma_config']['params']['use_checkpoint'] = False
model = instantiate_from_config(config.model)
state_dict = torch.load(args.ckpt_path, map_location="cpu")
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
model.load_state_dict(state_dict, strict=True)
model.eval()
model.model.to(torch.bfloat16)
model.diffusion_autocast_dtype = torch.bfloat16
apply_torch_compile(model)
model = model.cuda()
return model
# ---------------------------------------------------------------------------
# CudaTimer — precise GPU timing via CUDA events
# ---------------------------------------------------------------------------
class CudaTimer:
"""Context manager for GPU-precise stage timing using CUDA events."""
def __init__(self, name, records):
self.name = name
self.records = records
self.start = torch.cuda.Event(enable_timing=True)
self.end = torch.cuda.Event(enable_timing=True)
def __enter__(self):
torch.cuda.synchronize()
self.start.record()
return self
def __exit__(self, *args):
self.end.record()
torch.cuda.synchronize()
elapsed = self.start.elapsed_time(self.end)
self.records[self.name].append(elapsed)
# ---------------------------------------------------------------------------
# HookProfiler — sub-module level timing inside UNet via hooks
# ---------------------------------------------------------------------------
class HookProfiler:
"""Register forward hooks on UNet sub-modules to collect per-call timing."""
# Coarse-grained targets (original)
COARSE_CLASSES = (
SpatialTransformer,
TemporalTransformer,
ResBlock,
ConditionalUnet1D,
)
# Fine-grained targets for deep DDIM analysis
FINE_CLASSES = (
CrossAttention,
FeedForward,
)
def __init__(self, unet, deep=False):
self.unet = unet
self.deep = deep
self.handles = []
# per-instance data: {instance_id: [(start_event, end_event), ...]}
self._events = defaultdict(list)
# tag mapping: {instance_id: (class_name, module_name)}
self._tags = {}
# block location: {instance_id: block_location_str}
self._block_loc = {}
@staticmethod
def _get_block_location(name):
"""Derive UNet block location from module name, e.g. 'input_blocks.3.1'."""
parts = name.split('.')
if len(parts) >= 2 and parts[0] == 'input_blocks':
return f"input_blocks.{parts[1]}"
elif len(parts) >= 1 and parts[0] == 'middle_block':
return "middle_block"
elif len(parts) >= 2 and parts[0] == 'output_blocks':
return f"output_blocks.{parts[1]}"
elif 'action_unet' in name:
return "action_unet"
elif 'state_unet' in name:
return "state_unet"
elif name == 'out' or name.startswith('out.'):
return "out"
return "other"
def register(self):
"""Attach pre/post forward hooks to target sub-modules + unet.out."""
target_classes = self.COARSE_CLASSES
if self.deep:
target_classes = target_classes + self.FINE_CLASSES
for name, mod in self.unet.named_modules():
if isinstance(mod, target_classes):
tag = type(mod).__name__
inst_id = id(mod)
self._tags[inst_id] = (tag, name)
self._block_loc[inst_id] = self._get_block_location(name)
self.handles.append(
mod.register_forward_pre_hook(self._make_pre_hook(inst_id)))
self.handles.append(
mod.register_forward_hook(self._make_post_hook(inst_id)))
# Also hook unet.out (nn.Sequential)
out_mod = self.unet.out
inst_id = id(out_mod)
self._tags[inst_id] = ("UNet.out", "out")
self._block_loc[inst_id] = "out"
self.handles.append(
out_mod.register_forward_pre_hook(self._make_pre_hook(inst_id)))
self.handles.append(
out_mod.register_forward_hook(self._make_post_hook(inst_id)))
def _make_pre_hook(self, inst_id):
events = self._events
def hook(module, input):
start = torch.cuda.Event(enable_timing=True)
start.record()
events[inst_id].append([start, None])
return hook
def _make_post_hook(self, inst_id):
events = self._events
def hook(module, input, output):
end = torch.cuda.Event(enable_timing=True)
end.record()
events[inst_id][-1][1] = end
return hook
def reset(self):
"""Clear collected events for a fresh run."""
self._events.clear()
def synchronize_and_collect(self):
"""Sync GPU and compute elapsed times. Returns (by_type, by_instance, by_block)."""
torch.cuda.synchronize()
by_type = defaultdict(lambda: {"total_ms": 0.0, "count": 0, "calls": []})
by_instance = {}
# by_block: {block_loc: {tag: {"total_ms", "count"}}}
by_block = defaultdict(lambda: defaultdict(lambda: {"total_ms": 0.0, "count": 0}))
for inst_id, pairs in self._events.items():
tag, mod_name = self._tags[inst_id]
block_loc = self._block_loc.get(inst_id, "other")
inst_times = []
for start_evt, end_evt in pairs:
if end_evt is not None:
ms = start_evt.elapsed_time(end_evt)
inst_times.append(ms)
by_type[tag]["total_ms"] += ms
by_type[tag]["count"] += 1
by_type[tag]["calls"].append(ms)
by_block[block_loc][tag]["total_ms"] += ms
by_block[block_loc][tag]["count"] += 1
by_instance[(tag, mod_name)] = inst_times
return dict(by_type), by_instance, dict(by_block)
def remove(self):
"""Remove all hooks."""
for h in self.handles:
h.remove()
self.handles.clear()
# ---------------------------------------------------------------------------
# Build dummy inputs matching the pipeline's expected shapes
# ---------------------------------------------------------------------------
def build_dummy_inputs(model, noise_shape):
"""Create synthetic observation dict and prompts for profiling."""
device = next(model.parameters()).device
B, C, T, H, W = noise_shape
dtype = torch.bfloat16
# observation.images.top: [B, C, O, H, W] (permuted to [B,O,C,H,W] inside pipeline)
O = 2
obs_images = torch.randn(B, 3, O, 320, 512, device=device, dtype=dtype)
obs_state = torch.randn(B, O, 16, device=device, dtype=dtype)
action = torch.randn(B, 16, 16, device=device, dtype=dtype)
observation = {
'observation.images.top': obs_images,
'observation.state': obs_state,
'action': action,
}
prompts = ["a robot arm performing a task"] * B
return observation, prompts
# ---------------------------------------------------------------------------
# Run one full pipeline pass with per-stage timing
# ---------------------------------------------------------------------------
def run_pipeline(model, observation, prompts, noise_shape, ddim_steps,
cfg_scale, hook_profiler):
"""Execute the full 7-stage pipeline, returning per-stage timing dict."""
records = defaultdict(list)
device = next(model.parameters()).device
B, C, T, H, W = noise_shape
dtype = torch.bfloat16
fs = torch.tensor([1] * B, dtype=torch.long, device=device)
# --- Stage 1: Image Embedding ---
with CudaTimer("1_Image_Embedding", records):
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=dtype)
with torch.autocast('cuda', dtype=torch.bfloat16):
cond_img_emb = model.embedder(cond_img)
cond_img_emb = model.image_proj_model(cond_img_emb)
# --- Stage 2: VAE Encode ---
with CudaTimer("2_VAE_Encode", records):
videos = img.permute(0, 2, 1, 3, 4) # [B, C, O, H, W]
b_v, c_v, t_v, h_v, w_v = videos.shape
vae_dtype = next(model.first_stage_model.parameters()).dtype
x_vae = rearrange(videos, 'b c t h w -> (b t) c h w').to(dtype=vae_dtype)
z = model.encode_first_stage(x_vae)
z = rearrange(z, '(b t) c h w -> b c t h w', b=b_v, t=t_v)
img_cat_cond = z[:, :, -1:, :, :]
img_cat_cond = repeat(img_cat_cond,
'b c t h w -> b c (repeat t) h w', repeat=T)
cond = {"c_concat": [img_cat_cond]}
vae_enc_input_bytes = x_vae.nelement() * x_vae.element_size()
vae_enc_output_bytes = z.nelement() * z.element_size()
# --- Stage 3: Text Conditioning ---
with CudaTimer("3_Text_Conditioning", records):
cond_ins_emb = model.get_learned_conditioning(prompts)
# --- Stage 4: State/Action Projectors ---
with CudaTimer("4_Projectors", records):
projector_dtype = next(model.state_projector.parameters()).dtype
with torch.autocast('cuda', dtype=torch.bfloat16):
cond_state_emb = model.state_projector(
observation['observation.state'].to(dtype=projector_dtype))
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
cond_action_emb = model.action_projector(
observation['action'].to(dtype=projector_dtype))
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
# Assemble cross-attention conditioning
cond["c_crossattn"] = [
torch.cat([cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb],
dim=1)
]
n_obs_acting = getattr(model, 'n_obs_steps_acting', 2)
cond["c_crossattn_action"] = [
observation['observation.images.top'][:, :, -n_obs_acting:],
observation['observation.state'][:, -n_obs_acting:],
True, # sim_mode
False,
]
# CFG: build unconditional conditioning if needed
uc = None
if cfg_scale != 1.0:
uc_crossattn = torch.zeros_like(cond["c_crossattn"][0])
uc = {
"c_concat": cond["c_concat"],
"c_crossattn": [uc_crossattn],
"c_crossattn_action": cond["c_crossattn_action"],
}
# --- Stage 5: DDIM Loop ---
ddim_sampler = DDIMSampler(model)
hook_profiler.reset()
with CudaTimer("5_DDIM_Loop", records):
with torch.autocast('cuda', dtype=torch.bfloat16):
samples, actions, states, _ = ddim_sampler.sample(
S=ddim_steps,
conditioning=cond,
batch_size=B,
shape=noise_shape[1:],
verbose=False,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
eta=1.0,
cfg_img=None,
mask=None,
x0=None,
fs=fs,
timestep_spacing='uniform',
guidance_rescale=0.0,
unconditional_conditioning_img_nonetext=None,
)
hook_by_type, hook_by_instance, hook_by_block = hook_profiler.synchronize_and_collect()
# --- Stage 6: VAE Decode ---
with CudaTimer("6_VAE_Decode", records):
batch_images = model.decode_first_stage(samples)
vae_dec_input_bytes = samples.nelement() * samples.element_size()
vae_dec_output_bytes = batch_images.nelement() * batch_images.element_size()
# --- Stage 7: Post-process ---
with CudaTimer("7_Post_Process", records):
batch_images_cpu = batch_images.cpu()
actions_cpu = actions.cpu()
states_cpu = states.cpu()
# Simulate video save overhead: clamp + uint8 conversion
_ = (batch_images_cpu.clamp(-1, 1) * 127.5 + 127.5).to(torch.uint8)
# Flatten single-element lists
stage_times = {k: v[0] for k, v in records.items()}
bandwidth_info = {
"vae_enc_input_bytes": vae_enc_input_bytes,
"vae_enc_output_bytes": vae_enc_output_bytes,
"vae_dec_input_bytes": vae_dec_input_bytes,
"vae_dec_output_bytes": vae_dec_output_bytes,
}
return stage_times, hook_by_type, hook_by_instance, hook_by_block, bandwidth_info
# ---------------------------------------------------------------------------
# Reporting
# ---------------------------------------------------------------------------
def print_stage_timing(all_runs_stages):
"""Table 1: Stage Timing — name | mean(ms) | std | percent."""
import numpy as np
stage_names = list(all_runs_stages[0].keys())
means = {}
stds = {}
for name in stage_names:
vals = [run[name] for run in all_runs_stages]
means[name] = np.mean(vals)
stds[name] = np.std(vals)
total = sum(means.values())
print()
print("=" * 72)
print("TABLE 1: STAGE TIMING")
print("=" * 72)
print(f"{'Stage':<25} {'Mean(ms)':>10} {'Std':>10} {'%':>8}")
print("-" * 72)
for name in stage_names:
pct = means[name] / total * 100 if total > 0 else 0
print(f"{name:<25} {means[name]:>10.1f} {stds[name]:>10.2f} {pct:>7.1f}%")
print("-" * 72)
print(f"{'TOTAL':<25} {total:>10.1f}")
print()
def print_unet_breakdown(all_runs_hooks):
"""Table 2: UNet Sub-Module Breakdown — type | total | count | per-call | percent."""
import numpy as np
# Aggregate across runs
agg = defaultdict(lambda: {"totals": [], "counts": []})
for hook_by_type in all_runs_hooks:
for tag, data in hook_by_type.items():
agg[tag]["totals"].append(data["total_ms"])
agg[tag]["counts"].append(data["count"])
print("=" * 80)
print("TABLE 2: UNET SUB-MODULE BREAKDOWN")
print("=" * 80)
print(f"{'Module Type':<25} {'Total(ms)':>10} {'Count':>7} {'Per-call':>10} {'%':>8}")
print("-" * 80)
grand_total = 0
rows = []
for tag, d in agg.items():
mean_total = np.mean(d["totals"])
mean_count = np.mean(d["counts"])
per_call = mean_total / mean_count if mean_count > 0 else 0
grand_total += mean_total
rows.append((tag, mean_total, mean_count, per_call))
rows.sort(key=lambda r: r[1], reverse=True)
for tag, mean_total, mean_count, per_call in rows:
pct = mean_total / grand_total * 100 if grand_total > 0 else 0
print(f"{tag:<25} {mean_total:>10.1f} {int(mean_count):>7} {per_call:>10.2f} {pct:>7.1f}%")
print("-" * 80)
print(f"{'TOTAL (hooked)':<25} {grand_total:>10.1f}")
print()
def print_block_timing(all_runs_blocks):
"""Table 2b: Per-UNet-block timing — which blocks are hottest."""
import numpy as np
# Aggregate: {block_loc: {tag: [total_ms_per_run, ...]}}
agg = defaultdict(lambda: defaultdict(list))
for by_block in all_runs_blocks:
for block_loc, tag_dict in by_block.items():
for tag, data in tag_dict.items():
agg[block_loc][tag].append(data["total_ms"])
# Compute per-block totals
block_totals = {}
for block_loc, tag_dict in agg.items():
block_totals[block_loc] = sum(np.mean(v) for v in tag_dict.values())
grand_total = sum(block_totals.values())
# Sort blocks in logical order
def block_sort_key(name):
if name.startswith("input_blocks."):
return (0, int(name.split('.')[1]))
elif name == "middle_block":
return (1, 0)
elif name.startswith("output_blocks."):
return (2, int(name.split('.')[1]))
elif name == "out":
return (3, 0)
elif name == "action_unet":
return (4, 0)
elif name == "state_unet":
return (5, 0)
return (9, 0)
sorted_blocks = sorted(block_totals.keys(), key=block_sort_key)
print("=" * 90)
print("TABLE 2b: PER-UNET-BLOCK TIMING (coarse modules, per DDIM loop)")
print("=" * 90)
print(f"{'Block':<22} {'Total(ms)':>10} {'%':>7} Breakdown")
print("-" * 90)
for block_loc in sorted_blocks:
total = block_totals[block_loc]
pct = total / grand_total * 100 if grand_total > 0 else 0
# Build breakdown string
parts = []
for tag, vals in sorted(agg[block_loc].items(),
key=lambda x: np.mean(x[1]), reverse=True):
parts.append(f"{tag}={np.mean(vals):.0f}")
breakdown = ", ".join(parts)
print(f"{block_loc:<22} {total:>10.1f} {pct:>6.1f}% {breakdown}")
print("-" * 90)
print(f"{'TOTAL':<22} {grand_total:>10.1f}")
print()
def print_attn_ff_breakdown(all_runs_hooks):
"""Table 2c: CrossAttention vs FeedForward breakdown (--deep mode)."""
import numpy as np
agg = defaultdict(list)
for hook_by_type in all_runs_hooks:
for tag, data in hook_by_type.items():
if tag in ("CrossAttention", "FeedForward"):
agg[tag].append(data["total_ms"])
if not agg:
return
print("=" * 70)
print("TABLE 2c: ATTENTION vs FEEDFORWARD (deep hooks)")
print("=" * 70)
print(f"{'Component':<25} {'Total(ms)':>10} {'%':>8}")
print("-" * 70)
grand = 0
rows = []
for tag in ("CrossAttention", "FeedForward"):
if tag in agg:
mean_t = np.mean(agg[tag])
grand += mean_t
rows.append((tag, mean_t))
for tag, mean_t in rows:
pct = mean_t / grand * 100 if grand > 0 else 0
print(f"{tag:<25} {mean_t:>10.1f} {pct:>7.1f}%")
print("-" * 70)
print(f"{'TOTAL (attn+ff)':<25} {grand:>10.1f}")
print()
def print_unet_detailed(all_runs_instances):
"""Print per-instance UNet sub-module detail (--detailed mode)."""
import numpy as np
# Use last run's data
by_instance = all_runs_instances[-1]
print("=" * 100)
print("DETAILED: PER-INSTANCE UNET SUB-MODULE TIMING (last run)")
print("=" * 100)
print(f"{'Type':<22} {'Module Name':<45} {'Calls':>6} {'Total(ms)':>10} {'Mean(ms)':>10}")
print("-" * 100)
rows = []
for (tag, mod_name), times in by_instance.items():
if len(times) == 0:
continue
total = sum(times)
mean = np.mean(times)
rows.append((tag, mod_name, len(times), total, mean))
rows.sort(key=lambda r: r[3], reverse=True)
for tag, mod_name, count, total, mean in rows:
short_name = mod_name[-42:] if len(mod_name) > 42 else mod_name
print(f"{tag:<22} {short_name:<45} {count:>6} {total:>10.2f} {mean:>10.3f}")
print()
def print_memory_summary(mem_before, mem_peak):
"""Table 3: Memory Summary."""
delta = mem_peak - mem_before
print("=" * 50)
print("TABLE 3: MEMORY SUMMARY")
print("=" * 50)
print(f" Initial allocated: {mem_before / 1e9:.2f} GB")
print(f" Peak allocated: {mem_peak / 1e9:.2f} GB")
print(f" Delta (pipeline): {delta / 1e9:.2f} GB")
print()
def print_throughput(all_runs_stages, all_bw, ddim_steps, cfg_scale):
"""Table 4: Throughput — total latency, per-step, per-UNet-forward, VAE bandwidth."""
import numpy as np
n_runs = len(all_runs_stages)
# Total latency
totals = []
for run in all_runs_stages:
totals.append(sum(run.values()))
mean_total = np.mean(totals)
# DDIM loop time
ddim_times = [run["5_DDIM_Loop"] for run in all_runs_stages]
mean_ddim = np.mean(ddim_times)
unet_calls = ddim_steps if cfg_scale == 1.0 else ddim_steps * 2
per_step = mean_ddim / ddim_steps
per_unet = mean_ddim / unet_calls
# VAE bandwidth
mean_enc_time = np.mean([run["2_VAE_Encode"] for run in all_runs_stages])
mean_dec_time = np.mean([run["6_VAE_Decode"] for run in all_runs_stages])
bw = all_bw[-1] # use last run's byte counts
enc_bytes = bw["vae_enc_input_bytes"] + bw["vae_enc_output_bytes"]
dec_bytes = bw["vae_dec_input_bytes"] + bw["vae_dec_output_bytes"]
enc_bw = enc_bytes / (mean_enc_time / 1000) / 1e9 if mean_enc_time > 0 else 0
dec_bw = dec_bytes / (mean_dec_time / 1000) / 1e9 if mean_dec_time > 0 else 0
print("=" * 60)
print("TABLE 4: THROUGHPUT")
print("=" * 60)
print(f" Total pipeline latency: {mean_total:.1f} ms")
print(f" DDIM loop latency: {mean_ddim:.1f} ms")
print(f" DDIM steps: {ddim_steps}")
print(f" CFG scale: {cfg_scale} ({'2x UNet/step' if cfg_scale != 1.0 else '1x UNet/step'})")
print(f" UNet forward calls: {unet_calls}")
print(f" Per DDIM step: {per_step:.1f} ms")
print(f" Per UNet forward: {per_unet:.1f} ms")
print(f" VAE encode bandwidth: {enc_bw:.1f} GB/s (peak HBM: {MEM_BW_GBS} GB/s)")
print(f" VAE decode bandwidth: {dec_bw:.1f} GB/s (peak HBM: {MEM_BW_GBS} GB/s)")
print(f" GPU BF16 peak: {PEAK_BF16_TFLOPS} TFLOPS")
print()
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
patch_norm_bypass_autocast()
parser = argparse.ArgumentParser(
description="Profile the full inference pipeline")
parser.add_argument("--ckpt_path", type=str, required=True)
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--ddim_steps", type=int, default=50)
parser.add_argument("--cfg_scale", type=float, default=1.0)
parser.add_argument("--n_runs", type=int, default=3)
parser.add_argument("--warmup", type=int, default=1)
parser.add_argument("--detailed", action="store_true",
help="Print per-instance UNet sub-module detail")
parser.add_argument("--deep", action="store_true",
help="Enable deep DDIM analysis: per-block, attn vs ff")
args = parser.parse_args()
noise_shape = [1, 4, 16, 40, 64]
# --- Load model ---
print("Loading model...")
model = load_model(args)
observation, prompts = build_dummy_inputs(model, noise_shape)
# --- Setup hook profiler ---
unet = model.model.diffusion_model
hook_profiler = HookProfiler(unet, deep=args.deep)
hook_profiler.register()
print(f"Registered hooks on {len(hook_profiler.handles)} sub-modules")
# --- Warmup ---
print(f"Warmup: {args.warmup} run(s)...")
with torch.no_grad():
for i in range(args.warmup):
run_pipeline(model, observation, prompts, noise_shape,
args.ddim_steps, args.cfg_scale, hook_profiler)
print(f" warmup {i+1}/{args.warmup} done")
# --- Measurement runs ---
print(f"Measuring: {args.n_runs} run(s)...")
torch.cuda.reset_peak_memory_stats()
mem_before = torch.cuda.memory_allocated()
all_stages = []
all_hooks = []
all_instances = []
all_blocks = []
all_bw = []
with torch.no_grad():
for i in range(args.n_runs):
stage_times, hook_by_type, hook_by_instance, hook_by_block, bw = run_pipeline(
model, observation, prompts, noise_shape,
args.ddim_steps, args.cfg_scale, hook_profiler)
all_stages.append(stage_times)
all_hooks.append(hook_by_type)
all_instances.append(hook_by_instance)
all_blocks.append(hook_by_block)
all_bw.append(bw)
total = sum(stage_times.values())
print(f" run {i+1}/{args.n_runs}: {total:.1f} ms total")
mem_peak = torch.cuda.max_memory_allocated()
# --- Reports ---
print_stage_timing(all_stages)
print_unet_breakdown(all_hooks)
print_block_timing(all_blocks)
if args.deep:
print_attn_ff_breakdown(all_hooks)
if args.detailed:
print_unet_detailed(all_instances)
print_memory_summary(mem_before, mem_peak)
print_throughput(all_stages, all_bw, args.ddim_steps, args.cfg_scale)
hook_profiler.remove()
print("Done.")
if __name__ == "__main__":
main()

View File

@@ -86,9 +86,8 @@ class CrossAttention(nn.Module):
self.relative_position_v = RelativePosition(
num_units=dim_head, max_relative_position=temporal_length)
else:
## only used for spatial attention, while NOT for temporal attention
if XFORMERS_IS_AVAILBLE and temporal_length is None:
self.forward = self.efficient_forward
## bmm fused-scale attention for all non-relative-position cases
self.forward = self.bmm_forward
self.video_length = video_length
self.image_cross_attention = image_cross_attention
@@ -234,6 +233,119 @@ class CrossAttention(nn.Module):
return self.to_out(out)
def bmm_forward(self, x, context=None, mask=None):
spatial_self_attn = (context is None)
k_ip, v_ip, out_ip = None, None, None
k_as, v_as, out_as = None, None, None
k_aa, v_aa, out_aa = None, None, None
h = self.heads
q = self.to_q(x)
context = default(context, x)
if self.image_cross_attention and not spatial_self_attn:
context_agent_state = context[:, :self.agent_state_context_len, :]
context_agent_action = context[:,
self.agent_state_context_len:self.
agent_state_context_len +
self.agent_action_context_len, :]
context_ins = context[:, self.agent_state_context_len +
self.agent_action_context_len:self.
agent_state_context_len +
self.agent_action_context_len +
self.text_context_len, :]
context_image = context[:, self.agent_state_context_len +
self.agent_action_context_len +
self.text_context_len:, :]
k = self.to_k(context_ins)
v = self.to_v(context_ins)
k_ip = self.to_k_ip(context_image)
v_ip = self.to_v_ip(context_image)
k_as = self.to_k_as(context_agent_state)
v_as = self.to_v_as(context_agent_state)
k_aa = self.to_k_aa(context_agent_action)
v_aa = self.to_v_aa(context_agent_action)
else:
if not spatial_self_attn:
context = context[:, :self.text_context_len, :]
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(q, k, v))
# baddbmm: fuse scale into GEMM → one kernel instead of matmul + mul
sim = torch.baddbmm(
torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
q, k.transpose(-1, -2), beta=0, alpha=self.scale)
del k
if exists(mask):
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b i j -> (b h) i j', h=h)
sim.masked_fill_(~(mask > 0.5), max_neg_value)
with torch.amp.autocast('cuda', enabled=False):
sim = sim.softmax(dim=-1)
out = torch.bmm(sim, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
if k_ip is not None and k_as is not None and k_aa is not None:
## image cross-attention
k_ip, v_ip = map(
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(k_ip, v_ip))
sim_ip = torch.baddbmm(
torch.empty(q.shape[0], q.shape[1], k_ip.shape[1], dtype=q.dtype, device=q.device),
q, k_ip.transpose(-1, -2), beta=0, alpha=self.scale)
del k_ip
with torch.amp.autocast('cuda', enabled=False):
sim_ip = sim_ip.softmax(dim=-1)
out_ip = torch.bmm(sim_ip, v_ip)
out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
## agent state cross-attention
k_as, v_as = map(
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(k_as, v_as))
sim_as = torch.baddbmm(
torch.empty(q.shape[0], q.shape[1], k_as.shape[1], dtype=q.dtype, device=q.device),
q, k_as.transpose(-1, -2), beta=0, alpha=self.scale)
del k_as
with torch.amp.autocast('cuda', enabled=False):
sim_as = sim_as.softmax(dim=-1)
out_as = torch.bmm(sim_as, v_as)
out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)
## agent action cross-attention
k_aa, v_aa = map(
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(k_aa, v_aa))
sim_aa = torch.baddbmm(
torch.empty(q.shape[0], q.shape[1], k_aa.shape[1], dtype=q.dtype, device=q.device),
q, k_aa.transpose(-1, -2), beta=0, alpha=self.scale)
del k_aa
with torch.amp.autocast('cuda', enabled=False):
sim_aa = sim_aa.softmax(dim=-1)
out_aa = torch.bmm(sim_aa, v_aa)
out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h)
if out_ip is not None and out_as is not None and out_aa is not None:
if self.cross_attention_scale_learnable:
out = out + \
self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha_ctx) + 1) + \
self.agent_state_cross_attention_scale * out_as * (torch.tanh(self.alpha_cas) + 1) + \
self.agent_action_cross_attention_scale * out_aa * (torch.tanh(self.alpha_caa) + 1)
else:
out = out + \
self.image_cross_attention_scale * out_ip + \
self.agent_state_cross_attention_scale * out_as + \
self.agent_action_cross_attention_scale * out_aa
return self.to_out(out)
def efficient_forward(self, x, context=None, mask=None):
spatial_self_attn = (context is None)
k, v, out = None, None, None

View File

@@ -1,14 +1,16 @@
2026-02-08 05:06:45.806187: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 05:06:45.809295: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 05:06:45.840950: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 05:06:45.840981: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 05:06:45.842814: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 05:06:45.851049: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 05:06:45.851316: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 18:28:48.960238: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 18:28:48.963331: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 18:28:48.995688: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 18:28:48.995732: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 18:28:48.997547: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 18:28:49.005673: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 18:28:49.005948: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 05:06:47.225477: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
2026-02-08 18:28:50.009660: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
@@ -18,15 +20,27 @@ INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:198: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
>>> Applying precision settings:
- Diffusion dtype: bf16
- Projector mode: bf16_full
- Encoder mode: bf16_full
- VAE dtype: fp32
✓ Diffusion model weights converted to bfloat16
✓ Projectors converted to bfloat16
✓ Encoders converted to bfloat16
✓ VAE kept in fp32 for best quality
⚠ Found 849 fp32 params, converting to bf16
✓ All parameters converted to bfloat16
✓ torch.compile: 3 ResBlocks in output_blocks[5, 8, 9]
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
@@ -49,11 +63,11 @@ DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
@@ -106,7 +120,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
9%|▉ | 1/11 [01:14<12:29, 74.95s/it]
18%|█▊ | 2/11 [02:23<10:40, 71.18s/it]
@@ -139,6 +153,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -1,14 +1,16 @@
2026-02-08 16:49:41.598605: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 16:49:41.601687: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 16:49:41.632954: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 16:49:41.632986: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 16:49:41.634849: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 16:49:41.643134: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 16:49:41.643414: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 18:43:46.463492: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 18:43:46.466714: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 18:43:46.498994: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 18:43:46.499029: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 18:43:46.500865: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 18:43:46.509069: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 18:43:46.509359: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 16:49:42.320864: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
2026-02-08 18:43:47.434136: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
@@ -18,7 +20,7 @@ INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
@@ -61,7 +63,7 @@ DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/8 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
>>> Step 0: generating actions ...
@@ -114,7 +116,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
12%|█▎ | 1/8 [01:12<08:27, 72.57s/it]
25%|██▌ | 2/8 [02:21<07:02, 70.44s/it]
@@ -138,6 +140,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>