10 Commits

Author SHA1 Message Date
qhy
bb274870c2 Clean up code 2026-02-10 12:46:12 +08:00
qhy
f1f92072e6 remove profile 2026-02-10 11:28:26 +08:00
qhy
ff920b85a2 Theoretical performance analysis 2026-02-10 10:10:09 +08:00
qhy
6630952d2b Save results asynchronously 2026-02-09 21:23:00 +08:00
qhy
bc78815acf Temporary script parameter changes 2026-02-07 21:28:54 +08:00
qhy
d5f6577fa8 Copy model object, skip model loading 2026-02-07 19:18:49 +08:00
qhy
7dcf9e8b89 VAE optimization; load model directly to GPU 2026-02-07 17:36:00 +08:00
qhy
aba2a90045 Operator fusion 2026-02-07 16:40:33 +08:00
25de36b9bc Add notes on current optimizations
with related parameter changes and effects
2026-01-19 16:58:37 +08:00
2fdcec6da0 Delete README.md 2026-01-19 16:39:49 +08:00
9 changed files with 868 additions and 1412 deletions

README.md
View File

@@ -1,228 +1,29 @@
# UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family
<p style="font-size: 1.2em;">
<a href="https://unigen-x.github.io/unifolm-world-model-action.github.io"><strong>Project Page</strong></a> |
<a href="https://huggingface.co/collections/unitreerobotics/unifolm-wma-0-68ca23027310c0ca0f34959c"><strong>Models</strong></a> |
<a href="https://huggingface.co/unitreerobotics/datasets"><strong>Dataset</strong></a>
</p>
<div align="center">
<p align="right">
<span> 🌎English </span> | <a href="README_cn.md"> 🇨🇳中文 </a>
</p>
</div>
<div align="justify">
<b>UnifoLM-WMA-0</b> is Unitree's open-source world-model-action architecture spanning multiple types of robotic embodiments, designed specifically for general-purpose robot learning. Its core component is a world model capable of understanding the physical interactions between robots and their environments. This world model provides two key functions: (a) <b>Simulation Engine</b>: operates as an interactive simulator to generate synthetic data for robot learning; (b) <b>Policy Enhancement</b>: connects with an action head and, by predicting future interaction processes with the world model, further optimizes decision-making performance.
</div>
# World Model Interaction: Mixed-Precision Acceleration Notes (case1)
## 🦾 Real-Robot Demonstrations
| <img src="assets/gifs/real_z1_stackbox.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> | <img src="assets/gifs/real_dual_stackbox.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> |
|:---:|:---:|
| <img src="assets/gifs/real_cleanup_pencils.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> | <img src="assets/gifs/real_g1_pack_camera.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> |
## Changed Locations
- Script path: `/home/dyz/unifolm-world-model-action/unitree_g1_pack_camera/case1/run_world_model_interaction.sh`
- Current status: some parameters that are normally not recommended to change (or require caution) have been modified; once the optimal values are confirmed, they will be fixed as defaults.
**Note: the top-right window shows the world model's prediction of future action videos.**
## New Parameters (may become defaults once the optimum is confirmed)
- `--diffusion_dtype {fp32,bf16}`: weight and forward dtype of the diffusion backbone; default `fp32`
- `--projector_mode {fp32,autocast,bf16_full}`: projector precision strategy; default `fp32`
- `--encoder_mode {fp32,autocast,bf16_full}`: encoder precision strategy; default `fp32`
- `--vae_dtype {fp32,bf16}`: weight and forward dtype of the VAE; default `fp32`
- `--export_casted_ckpt <path>`: export a checkpoint under the current precision settings (for offline export of mixed-precision weights)
- `--export_only`: export the checkpoint and exit; off by default
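As a sketch, the flags above could be declared with `argparse` roughly as follows. The names and defaults are taken from the list above; the actual parser in the inference script may differ in detail.

```python
import argparse

# Hedged sketch of the new CLI flags described above; help strings paraphrase
# the list in this README and are not copied from the real script.
parser = argparse.ArgumentParser()
parser.add_argument("--diffusion_dtype", choices=["fp32", "bf16"], default="fp32",
                    help="Weight and forward dtype of the diffusion backbone")
parser.add_argument("--projector_mode", choices=["fp32", "autocast", "bf16_full"],
                    default="fp32", help="Projector precision strategy")
parser.add_argument("--encoder_mode", choices=["fp32", "autocast", "bf16_full"],
                    default="fp32", help="Encoder precision strategy")
parser.add_argument("--vae_dtype", choices=["fp32", "bf16"], default="fp32",
                    help="Weight and forward dtype of the VAE")
parser.add_argument("--export_casted_ckpt", type=str, default=None,
                    help="Export a checkpoint under the current precision settings")
parser.add_argument("--export_only", action="store_true",
                    help="Export the checkpoint and exit")

args = parser.parse_args(["--diffusion_dtype", "bf16"])
print(args.diffusion_dtype)  # bf16
```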
## 🔥 News
### Parameter Semantics
- `fp32`: weights and forward pass both use fp32
- `autocast`: weights stay fp32; the forward pass runs under `torch.autocast` (per-operator mixed precision)
- `bf16_full`: weights are explicitly cast to bf16; the forward pass also runs mainly in bf16
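A minimal sketch of the three strategies on a toy `nn.Linear` (CPU autocast with bf16 is used here for illustration; the actual script targets CUDA):

```python
from contextlib import nullcontext

import torch
from torch import nn

def apply_precision(module: nn.Module, mode: str):
    """Return (module, forward_context) for one of the three strategies above."""
    if mode == "fp32":        # weights + forward both in fp32
        return module.float(), nullcontext()
    if mode == "autocast":    # fp32 weights; per-operator mixed precision in forward
        return module.float(), torch.autocast("cpu", dtype=torch.bfloat16)
    if mode == "bf16_full":   # weights cast to bf16; forward mostly in bf16
        return module.to(torch.bfloat16), nullcontext()
    raise ValueError(f"unknown mode: {mode}")

proj, ctx = apply_precision(nn.Linear(8, 8), "autocast")
with ctx:
    out = proj(torch.randn(2, 8))
print(proj.weight.dtype, out.dtype)  # weights stay fp32; output is bf16 under autocast
```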
* Sep 22, 2025: 🚀 We released the deployment code for assisting experiments with [Unitree](https://www.unitree.com/) robots.
* Sep 15, 2025: 🚀 We released the training and inference code along with the model weights of [**UnifoLM-WMA-0**](https://huggingface.co/collections/unitreerobotics/unifolm-wma-0-68ca23027310c0ca0f34959c).
## Current Best Configuration and Results
### Configuration
- All modules in bf16 except the VAE
- Mixed-precision checkpoint exported offline via `--export_casted_ckpt`
## 📑 Opensource Plan
- [x] Training
- [x] Inference
- [x] Checkpoints
- [x] Deployment
### Results
- Wall time: reduced from `15m6s` to `7m5s`
- PSNR: drops by less than `4` dB (`35 -> 31`)
- GPU memory: usage down to about `50%` of the original
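For reference, the PSNR figures above follow the usual definition below (a sketch; the evaluation script's exact implementation and data range are assumptions here):

```python
import numpy as np

def psnr(ref: np.ndarray, test: np.ndarray, data_range: float = 255.0) -> float:
    """Peak signal-to-noise ratio in dB; higher means closer to the reference."""
    mse = np.mean((ref.astype(np.float64) - test.astype(np.float64)) ** 2)
    if mse == 0.0:
        return float("inf")
    return 10.0 * np.log10(data_range**2 / mse)

# A drop from ~35 dB to ~31 dB corresponds to roughly 2.5x higher MSE
# (10^0.4 ~= 2.51), since mse = data_range^2 / 10^(psnr/10).
for p in (35.0, 31.0):
    print(p, 255.0**2 / 10 ** (p / 10))
```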
## ⚙️ Installation
```
conda create -n unifolm-wma python==3.10.18
conda activate unifolm-wma
conda install pinocchio=3.2.0 -c conda-forge -y
conda install ffmpeg=7.1.1 -c conda-forge
git clone --recurse-submodules https://github.com/unitreerobotics/unifolm-world-model-action.git
# If you already downloaded the repo:
cd unifolm-world-model-action
git submodule update --init --recursive
pip install -e .
cd external/dlimp
pip install -e .
```
## 🧰 Model Checkpoints
| Model | Description | Link|
|---------|-------|------|
|$\text{UnifoLM-WMA-0}_{Base}$| Fine-tuned on [Open-X](https://robotics-transformer-x.github.io/) dataset. | [HuggingFace](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0-Base)|
|$\text{UnifoLM-WMA-0}_{Dual}$| Fine-tuned on five [Unitree opensource dataset](https://huggingface.co/collections/unitreerobotics/g1-dex1-datasets-68bae98bf0a26d617f9983ab) in both decision-making and simulation modes. | [HuggingFace](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0-Dual)|
## 🛢️ Dataset
In our experiments, we consider the following open-source datasets:
| Dataset | Robot | Link |
|---------|-------|------|
|Z1_StackBox| [Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_StackBox_Dataset/tree/v2.1)|
|Z1_DualArm_StackBox|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset/tree/v2.1)|
|Z1_DualArm_StackBox_V2|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset_V2/tree/v2.1)|
|Z1_DualArm_Cleanup_Pencils|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_CleanupPencils_Dataset/tree/v2.1)|
|G1_Pack_Camera|[Unitree G1](https://www.unitree.com/g1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/G1_Dex1_MountCameraRedGripper_Dataset/tree/v2.1)|
To train on your own dataset, first format the data following the [Huggingface LeRobot V2.1](https://github.com/huggingface/lerobot) dataset format. Assume the dataset source directory structure is as follows:
```
source_dir/
├── dataset1_name
├── dataset2_name
├── dataset3_name
└── ...
```
Then, convert a dataset to the required format using the command below:
```
cd prepare_data
python prepare_training_data.py \
--source_dir /path/to/your/source_dir \
--target_dir /path/to/save/the/converted/data \
--dataset_name "dataset1_name" \
--robot_name "a tag of the robot in the dataset" # e.g., Unitree Z1 Robot Arm or Unitree G1 Robot with Gripper.
```
The resulting data structure is as follows. (Note: model training only supports input from the main-view camera; if the dataset includes multiple views, remove the corresponding values from the ```data_dir``` column in the CSV file.)
```
target_dir/
├── videos
│ ├──dataset1_name
│ │ ├──camera_view_dir
│ │ ├── 0.mp4
│ │ ├── 1.mp4
│ │ └── ...
│ └── ...
├── transitions
│ ├── dataset1_name
│ ├── meta_data
│ ├── 0.h5
│ ├── 1.h5
│ └── ...
└── dataset1_name.csv
```
## 🚴‍♂️ Training
A. Our training strategy is outlined as follows:
- **Step 1**: Fine-tune a video generation model as the world model using the [Open-X](https://robotics-transformer-x.github.io/) dataset;
- **Step 2**: Post-train $\text{UnifoLM-WMA}$ in decision-making mode on the downstream task dataset;
<div align="left">
<img src="assets/pngs/dm_mode.png" width="600">
</div>
- **Step 3**: Post-train $\text{UnifoLM-WMA}$ in simulation mode on the downstream task dataset.
<div align="left">
<img src="assets/pngs/sim_mode.png" width="600">
</div>
**Note**: If you only require $\text{UnifoLM-WMA}$ to operate in a single mode, you may skip the corresponding step.
B. To conduct training on a single or multiple datasets, please follow the steps below:
- **Step 1**: The maximum DoF is assumed to be 16; if you have more than 16 DoF, update ```agent_state_dim``` and ```agent_action_dim``` in [configs/train/config.yaml](https://github.com/unitreerobotics/unifolm-wma/blob/working/configs/train/config.yaml);
- **Step 2**: Set up the input shapes for each modality in [configs/train/meta.json](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/meta.json);
- **Step 3**: Configure the training parameters in [configs/train/config.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/config.yaml). For the ```pretrained_checkpoint```, we recommend using the checkpoint " $\text{UnifoLM-WMA-0}_{Base}$ " fine-tuned on the [Open-X](https://robotics-transformer-x.github.io/) dataset;
```yaml
model:
pretrained_checkpoint: /path/to/pretrained/checkpoint
...
decision_making_only: True # Train the world model only in decision-making mode. If False, jointly train it in both decision-making and simulation modes.
...
data:
...
train:
...
data_dir: /path/to/training/dataset/directory
dataset_and_weights: # list the name of each dataset below and make sure the summation of weights is 1.0
dataset1_name: 0.2
dataset2_name: 0.2
dataset3_name: 0.2
dataset4_name: 0.2
dataset5_name: 0.2
```
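Since the loader samples datasets by these values, a quick sanity check that the weights sum to 1.0 can save a failed run. This is a hypothetical helper, not part of the codebase; the dataset names mirror the config above.

```python
def check_dataset_weights(weights: dict, tol: float = 1e-6) -> float:
    """Raise if dataset sampling weights do not sum to 1.0; return the total."""
    total = sum(weights.values())
    if abs(total - 1.0) > tol:
        raise ValueError(f"dataset_and_weights must sum to 1.0, got {total}")
    return total

weights = {f"dataset{i}_name": 0.2 for i in range(1, 6)}
check_dataset_weights(weights)  # passes; an unbalanced dict would raise ValueError
```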
- **Step 4**: Setup ```experiment_name```, ```save_root``` variables in [scripts/train.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/train.sh);
- **Step 5**: Launch the training with the command:
```
bash scripts/train.sh
```
## 🌏 Inference under Interactive Simulation Mode
To run the world model in an interactive simulation mode, follow these steps:
- **Step 1**: (Skip this step if you just want to test with the provided examples) Prepare your own prompt following the format used in [examples/world_model_interaction_prompts](https://github.com/unitreerobotics/unitree-world-model/tree/main/examples/world_model_interaction_prompts):
```
world_model_interaction_prompts/
├── images
│ ├── dataset1_name
│ │ ├── 0.png # Image prompt
│ │ └── ...
│ └── ...
├── transitions
│ ├── dataset1_name
│ │ ├── meta_data # Used for normalization
│ │ ├── 0.h5 # Robot state and action data; in interaction mode,
│ │ │ # only used to retrieve the robot state corresponding
│ │ │ # to the image prompt
│ │ └── ...
│ └── ...
├── dataset1_name.csv # File for loading image prompts, text instruction and corresponding robot states
└── ...
```
- **Step 2**: Specify the correct paths for ```pretrained_checkpoint``` (e.g., $\text{UnifoLM-WMA-0}_{Dual}$) and ```data_dir``` in [configs/inference/world_model_interaction.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/inference/world_model_interaction.yaml)
- **Step 3**: Set the paths for ```checkpoint```, ```res_dir``` and ```prompt_dir``` in [scripts/run_world_model_interaction.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/run_world_model_interaction.sh), and specify all the dataset names in ```datasets=(...)```. Then, launch the inference with the command:
```
bash scripts/run_world_model_interaction.sh
```
## 🧠 Inference and Deployment under Decision-Making Mode
In this setup, inference is performed on a server, while a robot client gathers observations from the real robot and sends them to the server to query actions. The process unfolds through the following steps:
### Server Setup:
- **Step-1**: Specify ```ckpt```, ```res_dir```, ```datasets``` in [scripts/run_real_eval_server.sh](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/scripts/run_real_eval_server.sh);
- **Step-2**: Configure ```data_dir``` and ```dataset_and_weights``` in [config/inference/world_model_decision_making.yaml](https://github.com/unitreerobotics/unifolm-world-model-action/blob/f12b4782652ca00452941d851b17446e4ee7124a/configs/inference/world_model_decision_making.yaml#L225);
- **Step-3**: Launch the server:
```
conda activate unifolm-wma
cd unifolm-world-model-action
bash scripts/run_real_eval_server.sh
```
### Client Setup
- **Step-1**: Follow the instructions in [unitree_deploy/README.md](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/unitree_deploy/README.md) to create the ```unitree_deploy``` conda environment, install the required packages, and launch the controllers or services on the real robot.
- **Step-2**: Open a new terminal and establish a tunnel connection from the client to the server:
```
ssh user_name@remote_server_IP -CNg -L 8000:127.0.0.1:8000
```
- **Step-3**: Run the ```unitree_deploy/robot_client.py``` script to start inference:
```
cd unitree_deploy
python scripts/robot_client.py --robot_type "g1_dex1" --action_horizon 16 --exe_steps 16 --observation_horizon 2 --language_instruction "pack black camera into box" --output_dir ./results --control_freq 15
```
## 📝 Codebase Architecture
Here's a high-level overview of the project's code structure and core components:
```
unitree-world-model/
├── assets # Media assets such as GIFs, images, and demo videos
├── configs # Configuration files for training and inference
│ ├── inference
│ └── train
├── examples # Example inputs and prompts for running inference
├── external # External packages
├── prepare_data # Scripts for dataset preprocessing and format conversion
├── scripts # Main scripts for training, evaluation, and deployment
├── src
│ ├──unitree_worldmodel # Core Python package for the Unitree world model
│ │ ├── data # Dataset loading, transformations, and dataloaders
│ │ ├── models # Model architectures and backbone definitions
│ │ ├── modules # Custom model modules and components
│ │ └── utils # Utility functions and common helpers
└── unitree_deploy # Deployment code
```
## 🙏 Acknowledgement
Much of the code is inherited from [DynamiCrafter](https://github.com/Doubiiu/DynamiCrafter), [Diffusion Policy](https://github.com/real-stanford/diffusion_policy), [ACT](https://github.com/MarkFzp/act-plus-plus) and [HPT](https://github.com/liruiw/HPT).
## 📝 Citation
```
@misc{unifolm-wma-0,
author = {Unitree},
title = {UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family},
year = {2025},
}
```

View File

@@ -222,7 +222,7 @@ data:
test:
target: unifolm_wma.data.wma_data.WMAData
params:
data_dir: '/home/dyz/unifolm-world-model-action/examples/world_model_interaction_prompts'
data_dir: '/home/qhy/unifolm-world-model-action/examples/world_model_interaction_prompts'
video_length: ${model.params.wma_config.params.temporal_length}
frame_stride: 2
load_raw_resolution: True

View File

@@ -9,10 +9,9 @@ import logging
import einops
import warnings
import imageio
import time
import json
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass, field, asdict
import atexit
from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext
from typing import Optional, Dict, List, Any, Mapping
from pytorch_lightning import seed_everything
@@ -21,376 +20,66 @@ from tqdm import tqdm
from einops import rearrange, repeat
from collections import OrderedDict
from torch import nn
from eval_utils import populate_queues, log_to_tensorboard
from eval_utils import populate_queues
from collections import deque
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.utils.utils import instantiate_from_config
# ========== Profiling Infrastructure ==========
@dataclass
class TimingRecord:
"""Record for a single timing measurement."""
name: str
start_time: float = 0.0
end_time: float = 0.0
cuda_time_ms: float = 0.0
count: int = 0
children: List['TimingRecord'] = field(default_factory=list)
@property
def cpu_time_ms(self) -> float:
return (self.end_time - self.start_time) * 1000
def to_dict(self) -> dict:
return {
'name': self.name,
'cpu_time_ms': self.cpu_time_ms,
'cuda_time_ms': self.cuda_time_ms,
'count': self.count,
'children': [c.to_dict() for c in self.children]
}
# ========== Async I/O ==========
_io_executor: Optional[ThreadPoolExecutor] = None
_io_futures: List[Any] = []
class ProfilerManager:
"""Manages macro and micro-level profiling."""
def _get_io_executor() -> ThreadPoolExecutor:
global _io_executor
if _io_executor is None:
_io_executor = ThreadPoolExecutor(max_workers=2)
return _io_executor
def __init__(
self,
enabled: bool = False,
output_dir: str = "./profile_output",
profile_detail: str = "light",
):
self.enabled = enabled
self.output_dir = output_dir
self.profile_detail = profile_detail
self.macro_timings: Dict[str, List[float]] = {}
self.cuda_events: Dict[str, List[tuple]] = {}
self.memory_snapshots: List[Dict] = []
self.pytorch_profiler = None
self.current_iteration = 0
self.operator_stats: Dict[str, Dict] = {}
self.profiler_config = self._build_profiler_config(profile_detail)
if enabled:
os.makedirs(output_dir, exist_ok=True)
def _build_profiler_config(self, profile_detail: str) -> Dict[str, Any]:
"""Return profiler settings based on the requested detail level."""
if profile_detail not in ("light", "full"):
raise ValueError(f"Unsupported profile_detail: {profile_detail}")
if profile_detail == "full":
return {
"record_shapes": True,
"profile_memory": True,
"with_stack": True,
"with_flops": True,
"with_modules": True,
"group_by_input_shape": True,
}
return {
"record_shapes": False,
"profile_memory": False,
"with_stack": False,
"with_flops": False,
"with_modules": False,
"group_by_input_shape": False,
}
@contextmanager
def profile_section(self, name: str, sync_cuda: bool = True):
"""Context manager for profiling a code section."""
if not self.enabled:
yield
return
if sync_cuda and torch.cuda.is_available():
torch.cuda.synchronize()
start_event = None
end_event = None
if torch.cuda.is_available():
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
start_time = time.perf_counter()
def _flush_io():
"""Wait for all pending async I/O to finish."""
global _io_futures
for fut in _io_futures:
try:
yield
finally:
if sync_cuda and torch.cuda.is_available():
torch.cuda.synchronize()
end_time = time.perf_counter()
cpu_time_ms = (end_time - start_time) * 1000
cuda_time_ms = 0.0
if start_event is not None and end_event is not None:
end_event.record()
torch.cuda.synchronize()
cuda_time_ms = start_event.elapsed_time(end_event)
if name not in self.macro_timings:
self.macro_timings[name] = []
self.macro_timings[name].append(cpu_time_ms)
if name not in self.cuda_events:
self.cuda_events[name] = []
self.cuda_events[name].append((cpu_time_ms, cuda_time_ms))
def record_memory(self, tag: str = ""):
"""Record current GPU memory state."""
if not self.enabled or not torch.cuda.is_available():
return
snapshot = {
'tag': tag,
'iteration': self.current_iteration,
'allocated_mb': torch.cuda.memory_allocated() / 1024**2,
'reserved_mb': torch.cuda.memory_reserved() / 1024**2,
'max_allocated_mb': torch.cuda.max_memory_allocated() / 1024**2,
}
self.memory_snapshots.append(snapshot)
def start_pytorch_profiler(self, wait: int = 1, warmup: int = 1, active: int = 3):
"""Start PyTorch profiler for operator-level analysis."""
if not self.enabled:
return nullcontext()
self.pytorch_profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(
wait=wait, warmup=warmup, active=active, repeat=1
),
on_trace_ready=self._trace_handler,
record_shapes=self.profiler_config["record_shapes"],
profile_memory=self.profiler_config["profile_memory"],
with_stack=self.profiler_config["with_stack"],
with_flops=self.profiler_config["with_flops"],
with_modules=self.profiler_config["with_modules"],
)
return self.pytorch_profiler
def _trace_handler(self, prof):
"""Handle profiler trace output."""
trace_path = os.path.join(
self.output_dir,
f"trace_iter_{self.current_iteration}.json"
)
prof.export_chrome_trace(trace_path)
# Extract operator statistics
key_averages = prof.key_averages(
group_by_input_shape=self.profiler_config["group_by_input_shape"]
)
for evt in key_averages:
op_name = evt.key
if op_name not in self.operator_stats:
self.operator_stats[op_name] = {
'count': 0,
'cpu_time_total_us': 0,
'cuda_time_total_us': 0,
'self_cpu_time_total_us': 0,
'self_cuda_time_total_us': 0,
'cpu_memory_usage': 0,
'cuda_memory_usage': 0,
'flops': 0,
}
stats = self.operator_stats[op_name]
stats['count'] += evt.count
stats['cpu_time_total_us'] += evt.cpu_time_total
stats['cuda_time_total_us'] += evt.cuda_time_total
stats['self_cpu_time_total_us'] += evt.self_cpu_time_total
stats['self_cuda_time_total_us'] += evt.self_cuda_time_total
if hasattr(evt, 'cpu_memory_usage'):
stats['cpu_memory_usage'] += evt.cpu_memory_usage
if hasattr(evt, 'cuda_memory_usage'):
stats['cuda_memory_usage'] += evt.cuda_memory_usage
if hasattr(evt, 'flops') and evt.flops:
stats['flops'] += evt.flops
def step_profiler(self):
"""Step the PyTorch profiler."""
if self.pytorch_profiler is not None:
self.pytorch_profiler.step()
def generate_report(self) -> str:
"""Generate comprehensive profiling report."""
if not self.enabled:
return "Profiling disabled."
report_lines = []
report_lines.append("=" * 80)
report_lines.append("PERFORMANCE PROFILING REPORT")
report_lines.append("=" * 80)
report_lines.append("")
# Macro-level timing summary
report_lines.append("-" * 40)
report_lines.append("MACRO-LEVEL TIMING SUMMARY")
report_lines.append("-" * 40)
report_lines.append(f"{'Section':<40} {'Count':>8} {'Total(ms)':>12} {'Avg(ms)':>12} {'CUDA Avg(ms)':>14}")
report_lines.append("-" * 86)
total_time = 0
timing_data = []
for name, times in sorted(self.macro_timings.items()):
cuda_times = [ct for _, ct in self.cuda_events.get(name, [])]
avg_time = np.mean(times)
avg_cuda = np.mean(cuda_times) if cuda_times else 0
total = sum(times)
total_time += total
timing_data.append({
'name': name,
'count': len(times),
'total_ms': total,
'avg_ms': avg_time,
'cuda_avg_ms': avg_cuda,
'times': times,
'cuda_times': cuda_times,
})
report_lines.append(f"{name:<40} {len(times):>8} {total:>12.2f} {avg_time:>12.2f} {avg_cuda:>14.2f}")
report_lines.append("-" * 86)
report_lines.append(f"{'TOTAL':<40} {'':<8} {total_time:>12.2f}")
report_lines.append("")
# Memory summary
if self.memory_snapshots:
report_lines.append("-" * 40)
report_lines.append("GPU MEMORY SUMMARY")
report_lines.append("-" * 40)
max_alloc = max(s['max_allocated_mb'] for s in self.memory_snapshots)
avg_alloc = np.mean([s['allocated_mb'] for s in self.memory_snapshots])
report_lines.append(f"Peak allocated: {max_alloc:>10.2f} MB")
report_lines.append(f"Average allocated: {avg_alloc:>10.2f} MB")
report_lines.append("")
# Top operators by CUDA time
if self.operator_stats:
report_lines.append("-" * 40)
report_lines.append("TOP 30 OPERATORS BY CUDA TIME")
report_lines.append("-" * 40)
sorted_ops = sorted(
self.operator_stats.items(),
key=lambda x: x[1]['cuda_time_total_us'],
reverse=True
)[:30]
report_lines.append(f"{'Operator':<50} {'Count':>8} {'CUDA(ms)':>12} {'CPU(ms)':>12} {'Self CUDA(ms)':>14}")
report_lines.append("-" * 96)
for op_name, stats in sorted_ops:
# Truncate long operator names
display_name = op_name[:47] + "..." if len(op_name) > 50 else op_name
report_lines.append(
f"{display_name:<50} {stats['count']:>8} "
f"{stats['cuda_time_total_us']/1000:>12.2f} "
f"{stats['cpu_time_total_us']/1000:>12.2f} "
f"{stats['self_cuda_time_total_us']/1000:>14.2f}"
)
report_lines.append("")
# Compute category breakdown
report_lines.append("-" * 40)
report_lines.append("OPERATOR CATEGORY BREAKDOWN")
report_lines.append("-" * 40)
categories = {
'Attention': ['attention', 'softmax', 'bmm', 'baddbmm'],
'Convolution': ['conv', 'cudnn'],
'Normalization': ['norm', 'layer_norm', 'batch_norm', 'group_norm'],
'Activation': ['relu', 'gelu', 'silu', 'sigmoid', 'tanh'],
'Linear/GEMM': ['linear', 'addmm', 'mm', 'gemm'],
'Memory': ['copy', 'contiguous', 'view', 'reshape', 'permute', 'transpose'],
'Elementwise': ['add', 'mul', 'div', 'sub', 'pow', 'exp', 'sqrt'],
}
category_times = {cat: 0.0 for cat in categories}
category_times['Other'] = 0.0
for op_name, stats in self.operator_stats.items():
op_lower = op_name.lower()
categorized = False
for cat, keywords in categories.items():
if any(kw in op_lower for kw in keywords):
category_times[cat] += stats['cuda_time_total_us']
categorized = True
break
if not categorized:
category_times['Other'] += stats['cuda_time_total_us']
total_op_time = sum(category_times.values())
report_lines.append(f"{'Category':<30} {'CUDA Time(ms)':>15} {'Percentage':>12}")
report_lines.append("-" * 57)
for cat, time_us in sorted(category_times.items(), key=lambda x: -x[1]):
pct = (time_us / total_op_time * 100) if total_op_time > 0 else 0
report_lines.append(f"{cat:<30} {time_us/1000:>15.2f} {pct:>11.1f}%")
report_lines.append("")
report = "\n".join(report_lines)
return report
def save_results(self):
"""Save all profiling results to files."""
if not self.enabled:
return
# Save report
report = self.generate_report()
report_path = os.path.join(self.output_dir, "profiling_report.txt")
with open(report_path, 'w') as f:
f.write(report)
print(f">>> Profiling report saved to: {report_path}")
# Save detailed JSON data
data = {
'macro_timings': {
name: {
'times': times,
'cuda_times': [ct for _, ct in self.cuda_events.get(name, [])]
}
for name, times in self.macro_timings.items()
},
'memory_snapshots': self.memory_snapshots,
'operator_stats': self.operator_stats,
}
json_path = os.path.join(self.output_dir, "profiling_data.json")
with open(json_path, 'w') as f:
json.dump(data, f, indent=2)
print(f">>> Detailed profiling data saved to: {json_path}")
# Print summary to console
print("\n" + report)
fut.result()
except Exception as e:
print(f">>> [async I/O] error: {e}")
_io_futures.clear()
# Global profiler instance
_profiler: Optional[ProfilerManager] = None
atexit.register(_flush_io)
def get_profiler() -> ProfilerManager:
"""Get the global profiler instance."""
global _profiler
if _profiler is None:
_profiler = ProfilerManager(enabled=False)
return _profiler
def init_profiler(enabled: bool, output_dir: str, profile_detail: str) -> ProfilerManager:
"""Initialize the global profiler."""
global _profiler
_profiler = ProfilerManager(
enabled=enabled,
output_dir=output_dir,
profile_detail=profile_detail,
)
return _profiler
def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None:
"""Synchronous save on CPU tensor (runs in background thread)."""
video = torch.clamp(video_cpu.float(), -1., 1.)
n = video.shape[0]
video = video.permute(2, 0, 1, 3, 4)
frame_grids = [
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
for framesheet in video
]
grid = torch.stack(frame_grids, dim=0)
grid = (grid + 1.0) / 2.0
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
torchvision.io.write_video(filename,
grid,
fps=fps,
video_codec='h264',
options={'crf': '10'})
def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
"""Submit video saving to background thread pool."""
video_cpu = video.detach().cpu()
fut = _get_io_executor().submit(_save_results_sync, video_cpu, filename, fps)
_io_futures.append(fut)
# ========== Original Functions ==========
@@ -458,7 +147,8 @@ def _load_state_dict(model: nn.Module,
def load_model_checkpoint(model: nn.Module,
ckpt: str,
assign: bool | None = None) -> nn.Module:
assign: bool | None = None,
device: str | torch.device = "cpu") -> nn.Module:
"""Load model weights from checkpoint file.
Args:
@@ -467,11 +157,12 @@ def load_model_checkpoint(model: nn.Module,
assign (bool | None): Whether to preserve checkpoint tensor dtypes
via load_state_dict(assign=True). If None, auto-enable when a
casted checkpoint metadata is detected.
device (str | torch.device): Target device for loaded tensors.
Returns:
nn.Module: Model with loaded weights.
"""
ckpt_data = torch.load(ckpt, map_location="cpu")
ckpt_data = torch.load(ckpt, map_location=device, mmap=True)
use_assign = False
if assign is not None:
use_assign = assign
@@ -511,9 +202,7 @@ def load_model_checkpoint(model: nn.Module,
def maybe_cast_module(module: nn.Module | None,
dtype: torch.dtype,
label: str,
profiler: Optional[ProfilerManager] = None,
profile_name: Optional[str] = None) -> None:
label: str) -> None:
if module is None:
return
try:
@@ -524,10 +213,6 @@ def maybe_cast_module(module: nn.Module | None,
if param.dtype == dtype:
print(f">>> {label} already {dtype}; skip cast")
return
ctx = nullcontext()
if profiler is not None and profile_name:
ctx = profiler.profile_section(profile_name)
with ctx:
module.to(dtype=dtype)
print(f">>> {label} cast to {dtype}")
@@ -746,8 +431,6 @@ def get_latent_z(model, videos: Tensor) -> Tensor:
Returns:
Tensor: Latent video tensor of shape [B, C, T, H, W].
"""
profiler = get_profiler()
with profiler.profile_section("get_latent_z/encode"):
b, c, t, h, w = videos.shape
x = rearrange(videos, 'b c t h w -> (b t) c h w')
vae_ctx = nullcontext()
@@ -862,8 +545,6 @@ def image_guided_synthesis_sim_mode(
actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
"""
profiler = get_profiler()
b, _, t, _, _ = noise_shape
ddim_sampler = getattr(model, "_ddim_sampler", None)
if ddim_sampler is None:
@@ -873,7 +554,6 @@ def image_guided_synthesis_sim_mode(
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
with profiler.profile_section("synthesis/conditioning_prep"):
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
if getattr(model, "encoder_bf16", False) and model.device.type == "cuda":
@@ -959,7 +639,6 @@ def image_guided_synthesis_sim_mode(
cond_z0 = None
if ddim_sampler is not None:
with profiler.profile_section("synthesis/ddim_sampling"):
autocast_ctx = nullcontext()
if diffusion_autocast_dtype is not None and model.device.type == "cuda":
autocast_ctx = torch.autocast("cuda", dtype=diffusion_autocast_dtype)
@@ -982,7 +661,6 @@ def image_guided_synthesis_sim_mode(
**kwargs)
# Reconstruct from latent to pixel space
with profiler.profile_section("synthesis/decode_first_stage"):
if getattr(model, "vae_bf16", False):
if samples.dtype != torch.bfloat16:
samples = samples.to(dtype=torch.bfloat16)
@@ -1012,21 +690,31 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
Returns:
None
"""
profiler = get_profiler()
# Create inference and tensorboard dirs
# Create inference dir
os.makedirs(args.savedir + '/inference', exist_ok=True)
log_dir = args.savedir + f"/tensorboard"
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir=log_dir)
# Load prompt
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
df = pd.read_csv(csv_path)
# Load config
with profiler.profile_section("model_loading/config"):
# Load config (always needed for data setup)
config = OmegaConf.load(args.config)
prepared_path = args.ckpt_path + ".prepared.pt"
if os.path.exists(prepared_path):
# ---- Fast path: load the fully-prepared model ----
print(f">>> Loading prepared model from {prepared_path} ...")
model = torch.load(prepared_path,
map_location=f"cuda:{gpu_no}",
weights_only=False,
mmap=True)
model.eval()
diffusion_autocast_dtype = (torch.bfloat16
if args.diffusion_dtype == "bf16"
else None)
print(f">>> Prepared model loaded.")
else:
# ---- Normal path: construct + checkpoint + casting ----
config['model']['params']['wma_config']['params'][
'use_checkpoint'] = False
model = instantiate_from_config(config.model)
@@ -1034,30 +722,18 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
with profiler.profile_section("model_loading/checkpoint"):
model = load_model_checkpoint(model, args.ckpt_path)
model = load_model_checkpoint(model, args.ckpt_path,
device=f"cuda:{gpu_no}")
model.eval()
model = model.cuda(gpu_no) # move residual buffers not in state_dict
print(f'>>> Load pre-trained model ...')
# Build unnomalizer
logging.info("***** Configing Data *****")
with profiler.profile_section("data_loading"):
data = instantiate_from_config(config.data)
data.setup()
print(">>> Dataset is successfully loaded ...")
with profiler.profile_section("model_to_cuda"):
model = model.cuda(gpu_no)
device = get_device_from_parameters(model)
diffusion_autocast_dtype = None
if args.diffusion_dtype == "bf16":
maybe_cast_module(
model.model,
torch.bfloat16,
"diffusion backbone",
profiler=profiler,
profile_name="model_loading/diffusion_bf16",
)
diffusion_autocast_dtype = torch.bfloat16
print(">>> diffusion backbone set to bfloat16")
@@ -1068,12 +744,29 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
model.first_stage_model,
vae_weight_dtype,
"VAE",
profiler=profiler,
profile_name="model_loading/vae_cast",
)
model.vae_bf16 = args.vae_dtype == "bf16"
print(f">>> VAE dtype set to {args.vae_dtype}")
# --- VAE performance optimizations ---
if hasattr(model, "first_stage_model") and model.first_stage_model is not None:
vae = model.first_stage_model
# torch.compile: fuses GroupNorm+SiLU, conv chains, etc.
if args.vae_compile:
vae.decoder = torch.compile(vae.decoder, mode="reduce-overhead")
vae.encoder = torch.compile(vae.encoder, mode="reduce-overhead")
print(">>> VAE encoder/decoder compiled with torch.compile (reduce-overhead)")
# Batch decode size
vae_decode_bs = args.vae_decode_bs if args.vae_decode_bs > 0 else 9999
model.vae_decode_bs = vae_decode_bs
model.vae_encode_bs = vae_decode_bs
if args.vae_decode_bs > 0:
print(f">>> VAE encode/decode batch size set to {args.vae_decode_bs}")
else:
print(">>> VAE encode/decode batch size: all frames at once")
encoder_mode = args.encoder_mode
encoder_bf16 = encoder_mode in ("autocast", "bf16_full")
encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32
@@ -1082,16 +775,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
model.cond_stage_model,
encoder_weight_dtype,
"cond_stage_model",
profiler=profiler,
profile_name="model_loading/encoder_cond_cast",
)
if hasattr(model, "embedder") and model.embedder is not None:
maybe_cast_module(
model.embedder,
encoder_weight_dtype,
"embedder",
profiler=profiler,
profile_name="model_loading/encoder_embedder_cast",
)
model.encoder_bf16 = encoder_bf16
model.encoder_mode = encoder_mode
@@ -1107,24 +796,18 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
model.image_proj_model,
projector_weight_dtype,
"image_proj_model",
profiler=profiler,
profile_name="model_loading/projector_image_cast",
)
if hasattr(model, "state_projector") and model.state_projector is not None:
maybe_cast_module(
model.state_projector,
projector_weight_dtype,
"state_projector",
profiler=profiler,
profile_name="model_loading/projector_state_cast",
)
if hasattr(model, "action_projector") and model.action_projector is not None:
maybe_cast_module(
model.action_projector,
projector_weight_dtype,
"action_projector",
profiler=profiler,
profile_name="model_loading/projector_action_cast",
)
if hasattr(model, "projector_bf16"):
model.projector_bf16 = projector_bf16
@@ -1148,7 +831,18 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
print(">>> export_only set; skipping inference.")
return
profiler.record_memory("after_model_load")
# Save prepared model for fast loading next time
if prepared_path:
print(f">>> Saving prepared model to {prepared_path} ...")
torch.save(model, prepared_path)
print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
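The hunk above adds a prepare-once / load-fast pattern: the fully-cast model is serialized after the first run, and later runs deserialize it instead of re-running construction, checkpoint loading, and dtype casting. A minimal stdlib sketch of the same pattern, with `pickle` standing in for `torch.save`/`torch.load` and a hypothetical `build_model` standing in for `instantiate_from_config` plus `load_model_checkpoint`:

```python
import os
import pickle
import tempfile

def load_or_prepare(cache_path, build_fn):
    """Return the cached object if present, else build it once and cache it."""
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)      # fast path: skip the expensive build
    obj = build_fn()                   # slow path: full construction
    with open(cache_path, "wb") as f:
        pickle.dump(obj, f)            # persist for the next run
    return obj

builds = []
def build_model():
    builds.append(1)                   # count how often the slow path runs
    return {"weights": [1.0, 2.0], "dtype": "bf16"}

cache = os.path.join(tempfile.mkdtemp(), "model.prepared.pkl")
first = load_or_prepare(cache, build_model)
second = load_or_prepare(cache, build_model)
print(len(builds))  # 1 -- the second call hit the cache
```

In the diff this is what cuts `model_loading/config` + `model_loading/checkpoint` (~61.6 s) down to `model_loading/prepared` (~4.7 s) in the profiler summary below.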
# Build normalizer (always needed, independent of model loading path)
logging.info("***** Configuring Data *****")
data = instantiate_from_config(config.data)
data.setup()
print(">>> Dataset is successfully loaded ...")
device = get_device_from_parameters(model)
# Run over data
assert (args.height % 16 == 0) and (
@@ -1163,10 +857,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
print(f'>>> Generate {n_frames} frames under each generation ...')
noise_shape = [args.bs, channels, n_frames, h, w]
# Determine profiler iterations
profile_active_iters = getattr(args, 'profile_iterations', 3)
use_pytorch_profiler = profiler.enabled and profile_active_iters > 0
# Start inference
for idx in range(0, len(df)):
sample = df.iloc[idx]
@@ -1182,7 +872,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Load transitions to get the initial state later
transition_path = get_transition_path(args.prompt_dir, sample)
with profiler.profile_section("load_transitions"):
with h5py.File(transition_path, 'r') as h5f:
transition_dict = {}
for key in h5f.keys():
@@ -1210,7 +899,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
}
# Obtain initial frame and state
with profiler.profile_section("prepare_init_input"):
start_idx = 0
model_input_fs = ori_fps // fs
batch, ori_state_dim, ori_action_dim = prepare_init_input(
@@ -1233,24 +921,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Update observation queues
cond_obs_queues = populate_queues(cond_obs_queues, observation)
# Setup PyTorch profiler context if enabled
pytorch_prof_ctx = nullcontext()
if use_pytorch_profiler:
pytorch_prof_ctx = profiler.start_pytorch_profiler(
wait=1, warmup=1, active=profile_active_iters
)
# Multi-round interaction with the world-model
with pytorch_prof_ctx:
for itr in tqdm(range(args.n_iter)):
log_every = max(1, args.step_log_every)
log_step = (itr % log_every == 0)
profiler.current_iteration = itr
profiler.record_memory(f"iter_{itr}_start")
with profiler.profile_section("iteration_total"):
# Get observation
with profiler.profile_section("prepare_observation"):
observation = {
'observation.images.top':
torch.stack(list(
@@ -1267,7 +943,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Use world-model in policy to generate action
if log_step:
print(f'>>> Step {itr}: generating actions ...')
with profiler.profile_section("action_generation"):
pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
model,
sample['instruction'],
@@ -1285,7 +960,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
diffusion_autocast_dtype=diffusion_autocast_dtype)
# Update future actions in the observation queues
with profiler.profile_section("update_action_queues"):
for act_idx in range(len(pred_actions[0])):
obs_update = {'action': pred_actions[0][act_idx:act_idx + 1]}
obs_update['action'][:, ori_action_dim:] = 0.0
@@ -1293,7 +967,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
obs_update)
# Collect data for interacting the world-model using the predicted actions
with profiler.profile_section("prepare_wm_observation"):
observation = {
'observation.images.top':
torch.stack(list(
@@ -1310,7 +983,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Interaction with the world-model
if log_step:
print(f'>>> Step {itr}: interacting with world model ...')
with profiler.profile_section("world_model_interaction"):
pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
model,
"",
@@ -1327,7 +999,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
guidance_rescale=args.guidance_rescale,
diffusion_autocast_dtype=diffusion_autocast_dtype)
with profiler.profile_section("update_state_queues"):
for step_idx in range(args.exe_steps):
obs_update = {
'observation.images.top':
@@ -1342,28 +1013,14 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
cond_obs_queues = populate_queues(cond_obs_queues,
obs_update)
# Save the imagined videos for decision-making
with profiler.profile_section("save_results"):
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
log_to_tensorboard(writer,
pred_videos_0,
sample_tag,
fps=args.save_fps)
# Save videos environment changes via world-model interaction
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
log_to_tensorboard(writer,
pred_videos_1,
sample_tag,
fps=args.save_fps)
# Save the imagined videos for decision-making
# Save the imagined videos for decision-making (async)
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
save_results(pred_videos_0.cpu(),
save_results_async(pred_videos_0,
sample_video_file,
fps=args.save_fps)
# Save videos environment changes via world-model interaction
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
save_results(pred_videos_1.cpu(),
save_results_async(pred_videos_1,
sample_video_file,
fps=args.save_fps)
@@ -1371,20 +1028,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Collect the result of world-model interactions
wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu())
profiler.record_memory(f"iter_{itr}_end")
profiler.step_profiler()
full_video = torch.cat(wm_video, dim=2)
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
log_to_tensorboard(writer,
full_video,
sample_tag,
fps=args.save_fps)
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
save_results(full_video, sample_full_video_file, fps=args.save_fps)
save_results_async(full_video, sample_full_video_file, fps=args.save_fps)
# Save profiling results
profiler.save_results()
# Wait for all async I/O to complete
_flush_io()
def get_parser():
@@ -1511,6 +1160,18 @@ def get_parser():
default="fp32",
help="Dtype for VAE/first_stage_model weights and forward autocast."
)
parser.add_argument(
"--vae_compile",
action='store_true',
default=False,
help="Apply torch.compile to VAE decoder for kernel fusion."
)
parser.add_argument(
"--vae_decode_bs",
type=int,
default=0,
help="VAE decode batch size (0=all frames at once). Reduces kernel launch overhead."
)
parser.add_argument(
"--export_casted_ckpt",
type=str,
@@ -1556,32 +1217,6 @@ def get_parser():
type=int,
default=8,
help="fps for the saving video")
# Profiling arguments
parser.add_argument(
"--profile",
action='store_true',
default=False,
help="Enable performance profiling (macro and operator-level analysis)."
)
parser.add_argument(
"--profile_output_dir",
type=str,
default=None,
help="Directory to save profiling results. Defaults to {savedir}/profile_output."
)
parser.add_argument(
"--profile_iterations",
type=int,
default=3,
help="Number of iterations to run PyTorch profiler's active phase for operator-level analysis."
)
parser.add_argument(
"--profile_detail",
type=str,
choices=["light", "full"],
default="light",
help="Profiling detail level. Use 'full' for shapes/stacks/memory/flops."
)
return parser
@@ -1593,15 +1228,5 @@ if __name__ == '__main__':
seed = random.randint(0, 2**31)
seed_everything(seed)
# Initialize profiler
profile_output_dir = args.profile_output_dir
if profile_output_dir is None:
profile_output_dir = os.path.join(args.savedir, "profile_output")
init_profiler(
enabled=args.profile,
output_dir=profile_output_dir,
profile_detail=args.profile_detail,
)
rank, gpu_num = 0, 1
run_inference(args, gpu_num, rank)

View File

@@ -99,7 +99,6 @@ class AutoencoderKL(pl.LightningModule):
print(f"Restored from {path}")
def encode(self, x, **kwargs):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)

View File

@@ -1073,11 +1073,15 @@ class LatentDiffusion(DDPM):
if not self.perframe_ae:
encoder_posterior = self.first_stage_model.encode(x)
results = self.get_first_stage_encoding(encoder_posterior).detach()
else: ## Consume less GPU memory but slower
else: ## Batch encode with configurable batch size
bs = getattr(self, 'vae_encode_bs', 1)
if bs >= x.shape[0]:
encoder_posterior = self.first_stage_model.encode(x)
results = self.get_first_stage_encoding(encoder_posterior).detach()
else:
results = []
for index in range(x.shape[0]):
frame_batch = self.first_stage_model.encode(x[index:index +
1, :, :, :])
for i in range(0, x.shape[0], bs):
frame_batch = self.first_stage_model.encode(x[i:i + bs])
frame_result = self.get_first_stage_encoding(
frame_batch).detach()
results.append(frame_result)
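This hunk generalizes the old one-frame-at-a-time loop to configurable chunks. The invariant being relied on (chunking at any batch size yields the same concatenated result as a single full-batch call) can be checked with a plain-Python sketch, where `process` stands in for `first_stage_model.encode`:

```python
def process(batch):
    # stand-in for first_stage_model.encode / decode on a batch of frames
    return [x * 2 for x in batch]

def process_chunked(frames, bs):
    if bs >= len(frames):
        return process(frames)            # all frames in one call
    out = []
    for i in range(0, len(frames), bs):   # same stride pattern as the diff
        out.extend(process(frames[i:i + bs]))
    return out

frames = list(range(17))
assert process_chunked(frames, 4) == process(frames)   # partial last chunk is fine
assert process_chunked(frames, 1) == process(frames)   # old per-frame path
print(len(process_chunked(frames, 9999)))  # 17
```

Larger chunks amortize kernel-launch overhead per frame; `bs = 1` recovers the old low-memory behavior, and the `vae_decode_bs` default of 0 maps to "all frames at once".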
@@ -1105,15 +1109,20 @@ class LatentDiffusion(DDPM):
else:
reshape_back = False
if not self.perframe_ae:
z = 1. / self.scale_factor * z
if not self.perframe_ae:
results = self.first_stage_model.decode(z, **kwargs)
else:
bs = getattr(self, 'vae_decode_bs', 1)
if bs >= z.shape[0]:
# all frames in one batch
results = self.first_stage_model.decode(z, **kwargs)
else:
results = []
for index in range(z.shape[0]):
frame_z = 1. / self.scale_factor * z[index:index + 1, :, :, :]
frame_result = self.first_stage_model.decode(frame_z, **kwargs)
results.append(frame_result)
for i in range(0, z.shape[0], bs):
results.append(
self.first_stage_model.decode(z[i:i + bs], **kwargs))
results = torch.cat(results, dim=0)
if reshape_back:

View File

@@ -55,16 +55,13 @@ class DDIMSampler(object):
to_torch(self.model.alphas_cumprod_prev))
# Calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod',
to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod',
to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# Computed directly on GPU to avoid CPU↔GPU transfers
ac = to_torch(alphas_cumprod)
self.register_buffer('sqrt_alphas_cumprod', ac.sqrt())
self.register_buffer('sqrt_one_minus_alphas_cumprod', (1. - ac).sqrt())
self.register_buffer('log_one_minus_alphas_cumprod', (1. - ac).log())
self.register_buffer('sqrt_recip_alphas_cumprod', ac.rsqrt())
self.register_buffer('sqrt_recipm1_alphas_cumprod', (1. / ac - 1).sqrt())
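Besides keeping the tensors on the GPU, the rewrite leans on two algebraic identities: `ac.rsqrt()` equals `np.sqrt(1. / ac)`, and `(1. / ac - 1).sqrt()` equals `np.sqrt((1 - ac) / ac)`. A quick scalar check (the value of `a` is arbitrary):

```python
import math

a = 0.73  # one illustrative alphas_cumprod entry

# np.sqrt(1. / a)    ->  ac.rsqrt()
assert abs(math.sqrt(1.0 / a) - a ** -0.5) < 1e-12
# np.sqrt(1. / a - 1)  ==  sqrt((1 - a) / a)
assert abs(math.sqrt(1.0 / a - 1.0) - math.sqrt((1.0 - a) / a)) < 1e-12
print("identities hold")
```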
# DDIM sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
@@ -86,6 +83,11 @@ class DDIMSampler(object):
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas',
torch.sqrt(1. - ddim_alphas))
# Precomputed coefficients for DDIM update formula
self.register_buffer('ddim_sqrt_alphas', ddim_alphas.sqrt())
self.register_buffer('ddim_sqrt_alphas_prev', ddim_alphas_prev.sqrt())
self.register_buffer('ddim_dir_coeff',
(1. - ddim_alphas_prev - ddim_sigmas**2).sqrt())
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
(1 - self.alphas_cumprod / self.alphas_cumprod_prev))
@@ -208,18 +210,11 @@ class DDIMSampler(object):
dp_ddim_scheduler_state = self.model.dp_noise_scheduler_state
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
action = torch.randn((b, 16, self.model.agent_action_dim),
device=device)
state = torch.randn((b, 16, self.model.agent_state_dim),
device=device)
else:
img = x_T
action = torch.randn((b, 16, self.model.agent_action_dim),
device=device)
state = torch.randn((b, 16, self.model.agent_state_dim),
device=device)
img = torch.randn(shape, device=device) if x_T is None else x_T
if precision is not None:
if precision == 16:
@@ -362,12 +357,13 @@ class DDIMSampler(object):
**kwargs)
else:
raise NotImplementedError
model_output = e_t_uncond + unconditional_guidance_scale * (
e_t_cond - e_t_uncond)
model_output_action = e_t_uncond_action + unconditional_guidance_scale * (
e_t_cond_action - e_t_uncond_action)
model_output_state = e_t_uncond_state + unconditional_guidance_scale * (
e_t_cond_state - e_t_uncond_state)
model_output = torch.lerp(e_t_uncond, e_t_cond,
unconditional_guidance_scale)
model_output_action = torch.lerp(e_t_uncond_action,
e_t_cond_action,
unconditional_guidance_scale)
model_output_state = torch.lerp(e_t_uncond_state, e_t_cond_state,
unconditional_guidance_scale)
if guidance_rescale > 0.0:
model_output = rescale_noise_cfg(
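`torch.lerp(start, end, weight)` computes `start + weight * (end - start)`, so with `weight = unconditional_guidance_scale` it reproduces the classifier-free-guidance combination exactly, including the extrapolation case where the scale exceeds 1. A scalar check of the equivalence:

```python
def lerp(a, b, w):
    # same definition torch.lerp uses: a + w * (b - a)
    return a + w * (b - a)

e_uncond, e_cond, scale = 0.25, 0.90, 7.5   # typical CFG scale > 1 extrapolates
old = e_uncond + scale * (e_cond - e_uncond)
new = lerp(e_uncond, e_cond, scale)
assert abs(old - new) < 1e-12
print(round(new, 4))  # 5.125
```

The gain is a single fused kernel (and fewer temporaries) in place of a subtract, multiply, and add.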
@@ -396,18 +392,28 @@ class DDIMSampler(object):
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
if use_original_steps:
sqrt_alphas = alphas.sqrt()
sqrt_alphas_prev = alphas_prev.sqrt()
dir_coeffs = (1. - alphas_prev - sigmas**2).sqrt()
else:
sqrt_alphas = self.ddim_sqrt_alphas
sqrt_alphas_prev = self.ddim_sqrt_alphas_prev
dir_coeffs = self.ddim_dir_coeff
if is_video:
size = (1, 1, 1, 1, 1)
else:
size = (1, 1, 1, 1)
a_t = alphas[index].view(size)
a_prev = alphas_prev[index].view(size)
sqrt_at = sqrt_alphas[index].view(size)
sqrt_a_prev = sqrt_alphas_prev[index].view(size)
sigma_t = sigmas[index].view(size)
sqrt_one_minus_at = sqrt_one_minus_alphas[index].view(size)
dir_coeff = dir_coeffs[index].view(size)
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
pred_x0 = (x - sqrt_one_minus_at * e_t) / sqrt_at
else:
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
@@ -420,14 +426,11 @@ class DDIMSampler(object):
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device,
repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
x_prev = sqrt_a_prev * pred_x0 + dir_coeff * e_t + noise
return x_prev, pred_x0, model_output_action, model_output_state
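The hunk folds `a_prev.sqrt()` and `(1. - a_prev - sigma_t**2).sqrt()` into buffers computed once at sampler setup instead of at every DDIM step; the per-step update itself is unchanged. A scalar version of the update confirms the two forms agree (the coefficient values are illustrative, and the stochastic noise term is zeroed for a deterministic check):

```python
import math

a_t, a_prev, sigma = 0.85, 0.92, 0.05
x, e_t, noise = 1.3, -0.2, 0.0

# original per-step computation
pred_x0 = (x - math.sqrt(1.0 - a_t) * e_t) / math.sqrt(a_t)
x_prev_old = (math.sqrt(a_prev) * pred_x0
              + math.sqrt(1.0 - a_prev - sigma ** 2) * e_t + noise)

# precomputed-buffer form from the diff
sqrt_a_prev = math.sqrt(a_prev)                   # ddim_sqrt_alphas_prev[index]
dir_coeff = math.sqrt(1.0 - a_prev - sigma ** 2)  # ddim_dir_coeff[index]
x_prev_new = sqrt_a_prev * pred_x0 + dir_coeff * e_t + noise

assert abs(x_prev_old - x_prev_new) < 1e-12
```

With 50 DDIM steps and two synthesis calls per iteration, hoisting these small sqrt kernels out of the inner loop removes hundreds of launches per generation.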
@@ -475,7 +478,7 @@ class DDIMSampler(object):
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_alphas_cumprod = self.ddim_sqrt_alphas
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:

View File

@@ -10,8 +10,8 @@ from unifolm_wma.utils.utils import instantiate_from_config
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
# swish / SiLU — single fused CUDA kernel instead of x * sigmoid(x)
return torch.nn.functional.silu(x)
def Normalize(in_channels, num_groups=32):
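`F.silu` is mathematically identical to `x * torch.sigmoid(x)`; the win is purely that it dispatches one fused CUDA kernel instead of a sigmoid kernel plus a multiply. A scalar check of the identity, using the closed form x / (1 + e^-x):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def silu(x):
    # SiLU / swish: x * sigmoid(x)
    return x * sigmoid(x)

for x in (-3.0, -0.5, 0.0, 0.5, 3.0):
    assert abs(silu(x) - x / (1.0 + math.exp(-x))) < 1e-12
print(round(silu(1.0), 6))  # 0.731059
```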

View File

@@ -7,79 +7,97 @@ MACRO-LEVEL TIMING SUMMARY
----------------------------------------
Section Count Total(ms) Avg(ms) CUDA Avg(ms)
--------------------------------------------------------------------------------------
action_generation 11 399707.47 36337.04 36336.85
data_loading 1 52.85 52.85 52.88
get_latent_z/encode 22 901.39 40.97 41.01
iteration_total 11 836793.23 76072.11 76071.63
load_transitions 1 2.24 2.24 2.28
model_loading/checkpoint 1 11833.31 11833.31 11833.43
model_loading/config 1 49774.19 49774.19 49774.16
model_to_cuda 1 8909.30 8909.30 8909.33
prepare_init_input 1 10.52 10.52 10.55
prepare_observation 11 5.41 0.49 0.53
prepare_wm_observation 11 2.12 0.19 0.22
save_results 11 38668.06 3515.28 3515.32
synthesis/conditioning_prep 22 2916.63 132.57 132.61
synthesis/ddim_sampling 22 782695.01 35577.05 35576.86
synthesis/decode_first_stage 22 12444.31 565.65 565.70
update_action_queues 11 6.85 0.62 0.65
update_state_queues 11 17.67 1.61 1.64
world_model_interaction 11 398375.58 36215.96 36215.75
action_generation 11 173133.54 15739.41 15739.36
data_loading 1 54.31 54.31 54.34
get_latent_z/encode 22 785.25 35.69 35.72
iteration_total 11 386482.08 35134.73 35134.55
load_transitions 1 2.07 2.07 2.10
model_loading/prepared 1 4749.22 4749.22 4749.83
prepare_init_input 1 29.19 29.19 29.22
prepare_observation 11 5.49 0.50 0.53
prepare_wm_observation 11 1.93 0.18 0.20
save_results 11 38791.18 3526.47 3526.51
synthesis/conditioning_prep 22 2528.23 114.92 114.95
synthesis/ddim_sampling 22 336003.29 15272.88 15272.83
synthesis/decode_first_stage 22 9095.14 413.42 413.46
update_action_queues 11 7.28 0.66 0.69
update_state_queues 11 17.38 1.58 1.61
world_model_interaction 11 174516.52 15865.14 15865.07
--------------------------------------------------------------------------------------
TOTAL 2543116.13
TOTAL 1126202.08
----------------------------------------
GPU MEMORY SUMMARY
----------------------------------------
Peak allocated: 17890.50 MB
Average allocated: 16129.98 MB
Peak allocated: 18188.29 MB
Average allocated: 9117.49 MB
----------------------------------------
TOP 30 OPERATORS BY CUDA TIME
----------------------------------------
Operator Count CUDA(ms) CPU(ms) Self CUDA(ms)
------------------------------------------------------------------------------------------------
ProfilerStep* 6 443804.16 237696.98 237689.25
aten::linear 171276 112286.23 13179.82 0.00
aten::addmm 81456 79537.36 3799.84 79296.37
ampere_sgemm_128x64_tn 26400 52052.10 0.00 52052.10
aten::matmul 90468 34234.05 6281.32 0.00
aten::_convolution 100242 33623.79 13105.89 0.00
aten::mm 89820 33580.74 3202.22 33253.18
aten::convolution 100242 33575.23 13714.47 0.00
aten::cudnn_convolution 98430 30932.19 8640.50 29248.12
ampere_sgemm_32x128_tn 42348 20394.52 0.00 20394.52
aten::conv2d 42042 18115.35 5932.30 0.00
ampere_sgemm_128x32_tn 40938 16429.81 0.00 16429.81
xformers::efficient_attention_forward_cutlass 24000 15222.23 2532.93 15120.44
fmha_cutlassF_f32_aligned_64x64_rf_sm80(Attenti... 24000 15121.31 0.00 15121.31
ampere_sgemm_64x64_tn 21000 14627.12 0.00 14627.12
aten::copy_ 231819 14504.87 127056.51 14038.39
aten::group_norm 87144 12033.73 10659.57 0.00
aten::native_group_norm 87144 11473.40 9449.36 11002.02
aten::conv3d 26400 8852.13 3365.43 0.00
void at::native::(anonymous namespace)::Rowwise... 87144 8714.68 0.00 8714.68
void cudnn::ops::nchwToNhwcKernel<float, float,... 169824 8525.44 0.00 8525.44
aten::clone 214314 8200.26 8568.82 0.00
void at::native::elementwise_kernel<128, 2, at:... 220440 8109.62 0.00 8109.62
void cutlass::Kernel<cutlass_80_simt_sgemm_128x... 15000 7919.30 0.00 7919.30
aten::_to_copy 12219 5963.43 122411.53 0.00
aten::to 58101 5952.65 122443.72 0.00
aten::conv1d 30000 5878.95 4556.48 0.00
Memcpy HtoD (Pageable -> Device) 6696 5856.39 0.00 5856.39
aten::reshape 671772 5124.03 9636.01 0.00
sm80_xmma_fprop_implicit_gemm_indexed_tf32f32_t... 16272 5097.70 0.00 5097.70
ProfilerStep* 18 690146.23 133688.74 616385.44
aten::group_norm 168624 24697.84 29217.27 0.00
aten::_convolution 96450 21420.26 12845.86 0.00
aten::convolution 96450 21408.68 13480.97 0.00
aten::linear 297398 20780.15 26257.38 0.00
aten::cudnn_convolution 94638 18660.24 8239.04 18329.28
aten::copy_ 772677 18135.46 17387.09 17864.87
aten::conv3d 52800 12922.42 8572.58 0.00
aten::conv2d 52469 12747.13 7725.70 0.00
aten::native_group_norm 84312 10285.37 8974.31 10197.66
aten::_to_copy 590277 10270.09 22570.90 0.00
aten::to 602979 9655.26 23666.06 0.00
aten::conv1d 56245 8174.37 10015.24 0.00
void at::native::(anonymous namespace)::Rowwise... 84312 7979.71 0.00 7979.71
aten::clone 177132 7502.90 7007.48 0.00
void cudnn::ops::nchwToNhwcKernel<__nv_bfloat16... 164700 7384.52 0.00 7384.52
aten::addmm 81456 6958.44 3903.01 6908.44
aten::layer_norm 65700 5698.92 7816.08 0.00
void at::native::elementwise_kernel<128, 4, at:... 149688 5372.46 0.00 5372.46
void at::native::unrolled_elementwise_kernel<at... 180120 5165.28 0.00 5165.28
ampere_bf16_s16816gemm_bf16_128x128_ldg8_relu_f... 24900 4449.05 0.00 4449.05
void at::native::unrolled_elementwise_kernel<at... 368664 4405.30 0.00 4405.30
aten::reshape 686778 3771.84 8309.51 0.00
aten::contiguous 46008 3400.88 1881.73 0.00
sm80_xmma_fprop_implicit_gemm_bf16bf16_bf16f32_... 15516 3398.03 0.00 3398.03
aten::matmul 90489 3366.62 4946.69 0.00
aten::mm 89820 3284.53 3308.76 3228.56
void at::native::elementwise_kernel<128, 2, at:... 46518 2441.55 0.00 2441.55
aten::add 113118 2426.66 2776.23 2385.52
void at::native::elementwise_kernel<128, 4, at:... 104550 2426.41 0.00 2426.41
----------------------------------------
OPERATOR CATEGORY BREAKDOWN
----------------------------------------
Category CUDA Time(ms) Percentage
---------------------------------------------------------
Other 481950.47 41.9%
Linear/GEMM 342333.09 29.8%
Convolution 159920.77 13.9%
Elementwise 54682.93 4.8%
Memory 36883.36 3.2%
Attention 34736.13 3.0%
Normalization 32081.19 2.8%
Activation 6449.19 0.6%
Other 723472.91 71.9%
Convolution 114469.81 11.4%
Memory 53845.46 5.4%
Normalization 46852.57 4.7%
Linear/GEMM 35354.58 3.5%
Elementwise 17078.44 1.7%
Activation 12296.29 1.2%
Attention 2956.61 0.3%
------------------------------------------------------------------------------------------
aten::addmm (Linear/GEMM) UTILIZATION ANALYSIS ON A100
------------------------------------------------------------------------------------------
Effective compute precision: BF16 Tensor Core (312 TFLOPS)
torch.backends.cuda.matmul.allow_tf32 = False
Metric Value
-------------------------------------------------------------------
Total aten::addmm calls 81,456
Total Self CUDA time 6908.44 ms
Total FLOPs (profiler)              1.33 PFLOPs
Achieved throughput                 191.88 TFLOP/s
A100 peak throughput                312.00 TFLOP/s
MFU (Model FLOPs Utilization) 61.50%
INTERPRETATION:
-------------------------------------------------------------------
Good utilization (>60%). GEMM kernels are compute-bound
and running efficiently on Tensor Cores.
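The MFU figure is just achieved throughput divided by peak. Recomputing it from the totals reported above (1.33 PFLOPs of GEMM work over 6 908.44 ms of self-CUDA time, against the A100's 312 TFLOP/s BF16 Tensor Core peak; the profiler's FLOP total is rounded, so the last digit can differ slightly from the report):

```python
total_flops = 1.33e15          # reported "Total FLOPs (profiler)", rounded
cuda_time_s = 6908.44 / 1e3    # reported Self CUDA time for aten::addmm
peak_flops = 312e12            # A100 BF16 Tensor Core peak

achieved = total_flops / cuda_time_s          # FLOP/s actually sustained
mfu = achieved / peak_flops
print(f"{achieved / 1e12:.1f} TFLOP/s, MFU {mfu:.1%}")  # 192.5 TFLOP/s, MFU 61.7%
```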

View File

@@ -2,7 +2,7 @@ res_dir="unitree_g1_pack_camera/case1"
dataset="unitree_g1_pack_camera"
{
time CUDA_VISIBLE_DEVICES=1 python3 scripts/evaluation/world_model_interaction.py \
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
--seed 123 \
--ckpt_path ckpts/unifolm_wma_dual_mixbf16.ckpt \
--config configs/inference/world_model_interaction.yaml \
@@ -23,5 +23,6 @@ dataset="unitree_g1_pack_camera"
--perframe_ae \
--diffusion_dtype bf16 \
--projector_mode bf16_full \
--encoder_mode bf16_full
--encoder_mode bf16_full \
--vae_dtype bf16
} 2>&1 | tee "${res_dir}/output.log"