10 Commits

Author SHA1 Message Date
qhy
bb274870c2 Clean up code 2026-02-10 12:46:12 +08:00
qhy
f1f92072e6 remove profile 2026-02-10 11:28:26 +08:00
qhy
ff920b85a2 Theoretical performance analysis 2026-02-10 10:10:09 +08:00
qhy
6630952d2b Save results asynchronously 2026-02-09 21:23:00 +08:00
qhy
bc78815acf Temporary script parameter changes 2026-02-07 21:28:54 +08:00
qhy
d5f6577fa8 Copy the model object to skip model loading 2026-02-07 19:18:49 +08:00
qhy
7dcf9e8b89 VAE optimization; load the model directly onto the GPU 2026-02-07 17:36:00 +08:00
qhy
aba2a90045 Operator fusion 2026-02-07 16:40:33 +08:00
25de36b9bc Document the current optimizations, related parameter changes, and their effects 2026-01-19 16:58:37 +08:00
2fdcec6da0 Delete README.md 2026-01-19 16:39:49 +08:00
9 changed files with 868 additions and 1412 deletions

README.md (245 changed lines)

@@ -1,228 +1,29 @@
**Added (new README content, translated from Chinese):**

# World Model Interaction Mixed-Precision Acceleration Notes (case1)

## Changed Locations
- Script path: `/home/dyz/unifolm-world-model-action/unitree_g1_pack_camera/case1/run_world_model_interaction.sh`
- Current status: some parameters that were originally not recommended (or that require caution) to modify have been changed; they will be fixed as defaults once the optimal settings are confirmed.

## New Parameters (may become defaults once confirmed optimal)
- `--diffusion_dtype {fp32,bf16}`: dtype for diffusion weights and forward pass; default `fp32`
- `--projector_mode {fp32,autocast,bf16_full}`: projector precision strategy; default `fp32`
- `--encoder_mode {fp32,autocast,bf16_full}`: encoder precision strategy; default `fp32`
- `--vae_dtype {fp32,bf16}`: dtype for VAE weights and forward pass; default `fp32`
- `--export_casted_ckpt <path>`: export a ckpt under the current precision settings (for offline export of mixed-precision weights)
- `--export_only`: export the ckpt and then exit; off by default

### Parameter Semantics
- `fp32`: both weights and forward pass use fp32
- `autocast`: weights stay fp32; forward runs under `torch.autocast` (op-level mixed precision)
- `bf16_full`: weights are explicitly cast to bf16; forward also runs mainly in bf16

## Current Best Configuration and Results
### Configuration
- All modules in bf16 except the VAE
- Mixed-precision ckpt exported offline (via `--export_casted_ckpt`)
### Results
- Runtime: reduced from `15m6s` to `7m5s`
- PSNR: drops by less than `4` (`35 -> 31`)
- GPU memory: usage reduced to about `50%` of the original

**Removed (beginning of the old README):**

# UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family
<p style="font-size: 1.2em;">
  <a href="https://unigen-x.github.io/unifolm-world-model-action.github.io"><strong>Project Page</strong></a> |
  <a href="https://huggingface.co/collections/unitreerobotics/unifolm-wma-0-68ca23027310c0ca0f34959c"><strong>Models</strong></a> |
  <a href="https://huggingface.co/unitreerobotics/datasets"><strong>Dataset</strong></a>
</p>
<div align="center">
  <p align="right">
    <span> 🌎English </span> | <a href="README_cn.md"> 🇨🇳中文 </a>
  </p>
</div>
<div align="justify">
<b>UnifoLM-WMA-0</b> is Unitree's open-source world-model-action architecture spanning multiple types of robotic embodiments, designed specifically for general-purpose robot learning. Its core component is a world model capable of understanding the physical interactions between robots and their environments. This world model provides two key functions: (a) <b>Simulation Engine</b>: operates as an interactive simulator to generate synthetic data for robot learning; (b) <b>Policy Enhancement</b>: connects with an action head and, by predicting future interaction processes with the world model, further optimizes decision-making performance.
</div>

## 🦾 Real-Robot Demonstrations
| <img src="assets/gifs/real_z1_stackbox.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> | <img src="assets/gifs/real_dual_stackbox.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> |
|:---:|:---:|
| <img src="assets/gifs/real_cleanup_pencils.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> | <img src="assets/gifs/real_g1_pack_camera.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> |

**Note: the top-right window shows the world model's prediction of future action videos.**

## 🔥 News
* Sep 22, 2025: 🚀 We released the deployment code for assisting experiments with [Unitree](https://www.unitree.com/) robots.
* Sep 15, 2025: 🚀 We released the training and inference code along with the model weights of [**UnifoLM-WMA-0**](https://huggingface.co/collections/unitreerobotics/unifolm-wma-0-68ca23027310c0ca0f34959c).

## 📑 Opensource Plan
- [x] Training
- [x] Inference
- [x] Checkpoints
- [x] Deployment
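The three precision strategies (`fp32`, `autocast`, `bf16_full`) described in the mixed-precision notes above can be sketched on a toy module. `TinyProjector` and `run` are hypothetical illustrations, not code from this repository, and the sketch uses CPU autocast so it runs without a GPU:

```python
# Minimal sketch of the three precision strategies; TinyProjector is a
# hypothetical stand-in module, not the repository's actual projector.
import torch
import torch.nn as nn


class TinyProjector(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)


def run(mode: str, x: torch.Tensor) -> torch.Tensor:
    m = TinyProjector()
    if mode == "fp32":
        # Weights and forward pass both stay in fp32.
        return m(x)
    if mode == "autocast":
        # Weights stay fp32; eligible ops run in bf16 under autocast.
        with torch.autocast("cpu", dtype=torch.bfloat16):
            return m(x)
    if mode == "bf16_full":
        # Weights are explicitly cast to bf16; forward runs in bf16.
        m = m.to(dtype=torch.bfloat16)
        return m(x.to(torch.bfloat16))
    raise ValueError(mode)


x = torch.randn(2, 8)
print(run("fp32", x).dtype)       # torch.float32
print(run("bf16_full", x).dtype)  # torch.bfloat16
```

The trade-off reported in the notes (large speedup, small PSNR drop) comes from `bf16_full` halving weight storage and bandwidth, while `autocast` only lowers the precision of individual operations.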
## ⚙️ Installation
```
conda create -n unifolm-wma python==3.10.18
conda activate unifolm-wma
conda install pinocchio=3.2.0 -c conda-forge -y
conda install ffmpeg=7.1.1 -c conda-forge
git clone --recurse-submodules https://github.com/unitreerobotics/unifolm-world-model-action.git
# If you already downloaded the repo:
cd unifolm-world-model-action
git submodule update --init --recursive
pip install -e .
cd external/dlimp
pip install -e .
```
## 🧰 Model Checkpoints
| Model | Description | Link|
|---------|-------|------|
|$\text{UnifoLM-WMA-0}_{Base}$| Fine-tuned on [Open-X](https://robotics-transformer-x.github.io/) dataset. | [HuggingFace](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0-Base)|
|$\text{UnifoLM-WMA-0}_{Dual}$| Fine-tuned on five [Unitree opensource dataset](https://huggingface.co/collections/unitreerobotics/g1-dex1-datasets-68bae98bf0a26d617f9983ab) in both decision-making and simulation modes. | [HuggingFace](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0-Dual)|
## 🛢️ Dataset
In our experiments, we consider the following open-source datasets:
| Dataset | Robot | Link |
|---------|-------|------|
|Z1_StackBox| [Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_StackBox_Dataset/tree/v2.1)|
|Z1_DualArm_StackBox|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset/tree/v2.1)|
|Z1_DualArm_StackBox_V2|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset_V2/tree/v2.1)|
|Z1_DualArm_Cleanup_Pencils|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_CleanupPencils_Dataset/tree/v2.1)|
|G1_Pack_Camera|[Unitree G1](https://www.unitree.com/g1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/G1_Dex1_MountCameraRedGripper_Dataset/tree/v2.1)|
To train on your own dataset, first make sure the data follows the [Huggingface LeRobot V2.1](https://github.com/huggingface/lerobot) dataset format. Assume the dataset's source directory structure is as follows:
```
source_dir/
├── dataset1_name
├── dataset2_name
├── dataset3_name
└── ...
```
Then, convert a dataset to the required format using the command below:
```
cd prepare_data
python prepare_training_data.py \
    --source_dir /path/to/your/source_dir \
    --target_dir /path/to/save/the/converted/data \
    --dataset_name "dataset1_name" \
    --robot_name "a tag of the robot in the dataset"  # e.g., Unitree Z1 Robot Arm or Unitree G1 Robot with Gripper.
```
The resulting data structure is shown below. (Note: model training only supports input from the main-view camera; if the dataset includes multiple views, remove the corresponding values from the ```data_dir``` column in the CSV file.)
```
target_dir/
├── videos
│   ├── dataset1_name
│   │   ├── camera_view_dir
│   │   │   ├── 0.mp4
│   │   │   ├── 1.mp4
│   │   │   └── ...
│   └── ...
├── transitions
│   ├── dataset1_name
│   │   ├── meta_data
│   │   ├── 0.h5
│   │   ├── 1.h5
│   │   └── ...
└── dataset1_name.csv
```
## 🚴‍♂️ Training
A. Our training strategy is outlined as follows:
- **Step 1**: Fine-tune a video generation model as the world model using the [Open-X](https://robotics-transformer-x.github.io/) dataset;
- **Step 2**: Post-train $\text{UnifoLM-WMA}$ in decision-making mode on the downstream task dataset;
<div align="left">
<img src="assets/pngs/dm_mode.png" width="600">
</div>
- **Step 3**: Post-train $\text{UnifoLM-WMA}$ in simulation mode on the downstream task dataset.
<div align="left">
<img src="assets/pngs/sim_mode.png" width="600">
</div>
**Note**: If you only require $\text{UnifoLM-WMA}$ to operate in a single mode, you may skip the corresponding step.
B. To conduct training on a single or multiple datasets, please follow the steps below:
- **Step 1**: The maximum DoF is assumed to be 16; if you have more than 16 DoF, update ```agent_state_dim``` and ```agent_action_dim``` in [configs/train/config.yaml](https://github.com/unitreerobotics/unifolm-wma/blob/working/configs/train/config.yaml);
- **Step 2**: Set up the input shapes for each modality in [configs/train/meta.json](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/meta.json);
- **Step 3**: Configure the training parameters in [configs/train/config.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/config.yaml). For the ```pretrained_checkpoint```, we recommend using the checkpoint " $\text{UnifoLM-WMA-0}_{Base}$ " fine-tuned on the [Open-X](https://robotics-transformer-x.github.io/) dataset;
```yaml
model:
  pretrained_checkpoint: /path/to/pretrained/checkpoint
  ...
  decision_making_only: True  # Train the world model only in decision-making mode. If False, jointly train it in both decision-making and simulation modes.
  ...
data:
  ...
  train:
    ...
    data_dir: /path/to/training/dataset/directory
    dataset_and_weights:  # List the name of each dataset below and make sure the weights sum to 1.0.
      dataset1_name: 0.2
      dataset2_name: 0.2
      dataset3_name: 0.2
      dataset4_name: 0.2
      dataset5_name: 0.2
```
- **Step 4**: Setup ```experiment_name```, ```save_root``` variables in [scripts/train.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/train.sh);
- **Step 5**: Launch the training with the command:
```
bash scripts/train.sh
```
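The training config requires the dataset sampling weights to sum to 1.0. A tiny pre-flight check can catch mistakes before launching; the `dataset_and_weights` dict here is a hypothetical stand-in mirroring the example values, not read from the actual YAML:

```python
# Hypothetical pre-flight check: dataset sampling weights must sum to 1.0.
import math

dataset_and_weights = {
    "dataset1_name": 0.2,
    "dataset2_name": 0.2,
    "dataset3_name": 0.2,
    "dataset4_name": 0.2,
    "dataset5_name": 0.2,
}

total = sum(dataset_and_weights.values())
# Use a tolerance: float addition of several 0.2s is not exact.
assert math.isclose(total, 1.0, abs_tol=1e-6), f"weights sum to {total}, not 1.0"
print("dataset weights OK")
```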
## 🌏 Inference under Interactive Simulation Mode
To run the world model in an interactive simulation mode, follow these steps:
- **Step 1**: (Skip this step if you just want to test with the provided examples.) Prepare your own prompt following the format used in the [examples/world_model_interaction_prompts](https://github.com/unitreerobotics/unitree-world-model/tree/main/examples/world_model_interaction_prompts):
```
world_model_interaction_prompts/
├── images
│ ├── dataset1_name
│ │ ├── 0.png # Image prompt
│ │ └── ...
│ └── ...
├── transitions
│ ├── dataset1_name
│ │ ├── meta_data # Used for normalization
│ │ ├── 0.h5 # Robot state and action data; in interaction mode,
│ │ │ # only used to retrieve the robot state corresponding
│ │ │ # to the image prompt
│ │ └── ...
│ └── ...
├── dataset1_name.csv # File for loading image prompts, text instruction and corresponding robot states
└── ...
```
- **Step 2**: Specify the correct paths for ```pretrained_checkpoint``` (e.g., $\text{UnifoLM-WMA-0}_{Dual}$) and ```data_dir``` in [configs/inference/world_model_interaction.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/inference/world_model_interaction.yaml);
- **Step 3**: Set the paths for ```checkpoint```, ```res_dir``` and ```prompt_dir``` in [scripts/run_world_model_interaction.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/run_world_model_interaction.sh), and specify all the dataset names in ```datasets=(...)```. Then, launch the inference with the command:
```
bash scripts/run_world_model_interaction.sh
```
## 🧠 Inference and Deployment under Decision-Making Mode
In this setup, inference is performed on a server, while a robot client gathers observations from the real robot and sends them to the server to query actions. The process unfolds through the following steps:
### Server Setup:
- **Step-1**: Specify ```ckpt```, ```res_dir```, ```datasets``` in [scripts/run_real_eval_server.sh](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/scripts/run_real_eval_server.sh);
- **Step-2**: Configure ```data_dir``` and ```dataset_and_weights``` in [configs/inference/world_model_decision_making.yaml](https://github.com/unitreerobotics/unifolm-world-model-action/blob/f12b4782652ca00452941d851b17446e4ee7124a/configs/inference/world_model_decision_making.yaml#L225);
- **Step-3**: Launch the server:
```
conda activate unifolm-wma
cd unifolm-world-model-action
bash scripts/run_real_eval_server.sh
```
### Client Setup
- **Step-1**: Follow the instructions in [unitree_deploy/README.md](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/unitree_deploy/README.md) to create the ```unitree_deploy``` conda environment, install the required packages, and launch the controllers or services on the real robot.
- **Step-2**: Open a new terminal and establish a tunnel connection from the client to the server:
```
ssh user_name@remote_server_IP -CNg -L 8000:127.0.0.1:8000
```
- **Step-3**: Run the ```unitree_deploy/robot_client.py``` script to start inference:
```
cd unitree_deploy
python scripts/robot_client.py --robot_type "g1_dex1" --action_horizon 16 --exe_steps 16 --observation_horizon 2 --language_instruction "pack black camera into box" --output_dir ./results --control_freq 15
```
## 📝 Codebase Architecture
Here's a high-level overview of the project's code structure and core components:
```
unitree-world-model/
├── assets # Media assets such as GIFs, images, and demo videos
├── configs # Configuration files for training and inference
│ ├── inference
│ └── train
├── examples # Example inputs and prompts for running inference
├── external # External packages
├── prepare_data # Scripts for dataset preprocessing and format conversion
├── scripts # Main scripts for training, evaluation, and deployment
├── src
│ ├──unitree_worldmodel # Core Python package for the Unitree world model
│ │ ├── data # Dataset loading, transformations, and dataloaders
│ │ ├── models # Model architectures and backbone definitions
│ │ ├── modules # Custom model modules and components
│ │ └── utils # Utility functions and common helpers
└── unitree_deploy # Deployment code
```
## 🙏 Acknowledgement
Much of the code is inherited from [DynamiCrafter](https://github.com/Doubiiu/DynamiCrafter), [Diffusion Policy](https://github.com/real-stanford/diffusion_policy), [ACT](https://github.com/MarkFzp/act-plus-plus) and [HPT](https://github.com/liruiw/HPT).
## 📝 Citation
```
@misc{unifolm-wma-0,
author = {Unitree},
title = {UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family},
year = {2025},
}
```


@@ -222,7 +222,7 @@ data:
  test:
    target: unifolm_wma.data.wma_data.WMAData
    params:
      # removed:
      # data_dir: '/home/dyz/unifolm-world-model-action/examples/world_model_interaction_prompts'
      # added:
      data_dir: '/home/qhy/unifolm-world-model-action/examples/world_model_interaction_prompts'
      video_length: ${model.params.wma_config.params.temporal_length}
      frame_stride: 2
      load_raw_resolution: True

@@ -9,10 +9,9 @@ import logging
import einops
import warnings
import imageio
# removed:
# import time
# import json
# from contextlib import contextmanager, nullcontext
# from dataclasses import dataclass, field, asdict
# added:
import atexit
from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext
from typing import Optional, Dict, List, Any, Mapping
from pytorch_lightning import seed_everything
@@ -21,376 +20,66 @@ from tqdm import tqdm
from einops import rearrange, repeat
from collections import OrderedDict
from torch import nn
# removed:
# from eval_utils import populate_queues, log_to_tensorboard
# added:
from eval_utils import populate_queues
from collections import deque
from torch import Tensor
# removed:
# from torch.utils.tensorboard import SummaryWriter
from PIL import Image
from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.utils.utils import instantiate_from_config
# added:
# ========== Async I/O ==========
_io_executor: Optional[ThreadPoolExecutor] = None
_io_futures: List[Any] = []


def _get_io_executor() -> ThreadPoolExecutor:
    global _io_executor
    if _io_executor is None:
        _io_executor = ThreadPoolExecutor(max_workers=2)
    return _io_executor

# removed:
# ========== Profiling Infrastructure ==========
@dataclass
class TimingRecord:
    """Record for a single timing measurement."""
    name: str
    start_time: float = 0.0
    end_time: float = 0.0
    cuda_time_ms: float = 0.0
    count: int = 0
    children: List['TimingRecord'] = field(default_factory=list)

    @property
    def cpu_time_ms(self) -> float:
        return (self.end_time - self.start_time) * 1000

    def to_dict(self) -> dict:
        return {
            'name': self.name,
            'cpu_time_ms': self.cpu_time_ms,
            'cuda_time_ms': self.cuda_time_ms,
            'count': self.count,
            'children': [c.to_dict() for c in self.children]
        }


class ProfilerManager:
    """Manages macro and micro-level profiling."""
def __init__(
self,
enabled: bool = False,
output_dir: str = "./profile_output",
profile_detail: str = "light",
):
self.enabled = enabled
self.output_dir = output_dir
self.profile_detail = profile_detail
self.macro_timings: Dict[str, List[float]] = {}
self.cuda_events: Dict[str, List[tuple]] = {}
self.memory_snapshots: List[Dict] = []
self.pytorch_profiler = None
self.current_iteration = 0
self.operator_stats: Dict[str, Dict] = {}
self.profiler_config = self._build_profiler_config(profile_detail)
if enabled:
os.makedirs(output_dir, exist_ok=True)
def _build_profiler_config(self, profile_detail: str) -> Dict[str, Any]:
"""Return profiler settings based on the requested detail level."""
if profile_detail not in ("light", "full"):
raise ValueError(f"Unsupported profile_detail: {profile_detail}")
if profile_detail == "full":
return {
"record_shapes": True,
"profile_memory": True,
"with_stack": True,
"with_flops": True,
"with_modules": True,
"group_by_input_shape": True,
}
return {
"record_shapes": False,
"profile_memory": False,
"with_stack": False,
"with_flops": False,
"with_modules": False,
"group_by_input_shape": False,
}
@contextmanager
def profile_section(self, name: str, sync_cuda: bool = True):
"""Context manager for profiling a code section."""
if not self.enabled:
yield
return
if sync_cuda and torch.cuda.is_available():
torch.cuda.synchronize()
start_event = None
end_event = None
if torch.cuda.is_available():
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
start_time = time.perf_counter()
        try:
            yield
        finally:
            if sync_cuda and torch.cuda.is_available():
                torch.cuda.synchronize()
# added:
def _flush_io():
    """Wait for all pending async I/O to finish."""
    global _io_futures
    for fut in _io_futures:
        try:
            fut.result()
        except Exception as e:
            print(f">>> [async I/O] error: {e}")
    _io_futures.clear()
end_time = time.perf_counter()
cpu_time_ms = (end_time - start_time) * 1000
cuda_time_ms = 0.0
if start_event is not None and end_event is not None:
end_event.record()
torch.cuda.synchronize()
cuda_time_ms = start_event.elapsed_time(end_event)
if name not in self.macro_timings:
self.macro_timings[name] = []
self.macro_timings[name].append(cpu_time_ms)
if name not in self.cuda_events:
self.cuda_events[name] = []
self.cuda_events[name].append((cpu_time_ms, cuda_time_ms))
def record_memory(self, tag: str = ""):
"""Record current GPU memory state."""
if not self.enabled or not torch.cuda.is_available():
return
snapshot = {
'tag': tag,
'iteration': self.current_iteration,
'allocated_mb': torch.cuda.memory_allocated() / 1024**2,
'reserved_mb': torch.cuda.memory_reserved() / 1024**2,
'max_allocated_mb': torch.cuda.max_memory_allocated() / 1024**2,
}
self.memory_snapshots.append(snapshot)
def start_pytorch_profiler(self, wait: int = 1, warmup: int = 1, active: int = 3):
"""Start PyTorch profiler for operator-level analysis."""
if not self.enabled:
return nullcontext()
self.pytorch_profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(
wait=wait, warmup=warmup, active=active, repeat=1
),
on_trace_ready=self._trace_handler,
record_shapes=self.profiler_config["record_shapes"],
profile_memory=self.profiler_config["profile_memory"],
with_stack=self.profiler_config["with_stack"],
with_flops=self.profiler_config["with_flops"],
with_modules=self.profiler_config["with_modules"],
)
return self.pytorch_profiler
def _trace_handler(self, prof):
"""Handle profiler trace output."""
trace_path = os.path.join(
self.output_dir,
f"trace_iter_{self.current_iteration}.json"
)
prof.export_chrome_trace(trace_path)
# Extract operator statistics
key_averages = prof.key_averages(
group_by_input_shape=self.profiler_config["group_by_input_shape"]
)
for evt in key_averages:
op_name = evt.key
if op_name not in self.operator_stats:
self.operator_stats[op_name] = {
'count': 0,
'cpu_time_total_us': 0,
'cuda_time_total_us': 0,
'self_cpu_time_total_us': 0,
'self_cuda_time_total_us': 0,
'cpu_memory_usage': 0,
'cuda_memory_usage': 0,
'flops': 0,
}
stats = self.operator_stats[op_name]
stats['count'] += evt.count
stats['cpu_time_total_us'] += evt.cpu_time_total
stats['cuda_time_total_us'] += evt.cuda_time_total
stats['self_cpu_time_total_us'] += evt.self_cpu_time_total
stats['self_cuda_time_total_us'] += evt.self_cuda_time_total
if hasattr(evt, 'cpu_memory_usage'):
stats['cpu_memory_usage'] += evt.cpu_memory_usage
if hasattr(evt, 'cuda_memory_usage'):
stats['cuda_memory_usage'] += evt.cuda_memory_usage
if hasattr(evt, 'flops') and evt.flops:
stats['flops'] += evt.flops
def step_profiler(self):
"""Step the PyTorch profiler."""
if self.pytorch_profiler is not None:
self.pytorch_profiler.step()
def generate_report(self) -> str:
"""Generate comprehensive profiling report."""
if not self.enabled:
return "Profiling disabled."
report_lines = []
report_lines.append("=" * 80)
report_lines.append("PERFORMANCE PROFILING REPORT")
report_lines.append("=" * 80)
report_lines.append("")
# Macro-level timing summary
report_lines.append("-" * 40)
report_lines.append("MACRO-LEVEL TIMING SUMMARY")
report_lines.append("-" * 40)
report_lines.append(f"{'Section':<40} {'Count':>8} {'Total(ms)':>12} {'Avg(ms)':>12} {'CUDA Avg(ms)':>14}")
report_lines.append("-" * 86)
total_time = 0
timing_data = []
for name, times in sorted(self.macro_timings.items()):
cuda_times = [ct for _, ct in self.cuda_events.get(name, [])]
avg_time = np.mean(times)
avg_cuda = np.mean(cuda_times) if cuda_times else 0
total = sum(times)
total_time += total
timing_data.append({
'name': name,
'count': len(times),
'total_ms': total,
'avg_ms': avg_time,
'cuda_avg_ms': avg_cuda,
'times': times,
'cuda_times': cuda_times,
})
report_lines.append(f"{name:<40} {len(times):>8} {total:>12.2f} {avg_time:>12.2f} {avg_cuda:>14.2f}")
report_lines.append("-" * 86)
report_lines.append(f"{'TOTAL':<40} {'':<8} {total_time:>12.2f}")
report_lines.append("")
# Memory summary
if self.memory_snapshots:
report_lines.append("-" * 40)
report_lines.append("GPU MEMORY SUMMARY")
report_lines.append("-" * 40)
max_alloc = max(s['max_allocated_mb'] for s in self.memory_snapshots)
avg_alloc = np.mean([s['allocated_mb'] for s in self.memory_snapshots])
report_lines.append(f"Peak allocated: {max_alloc:>10.2f} MB")
report_lines.append(f"Average allocated: {avg_alloc:>10.2f} MB")
report_lines.append("")
# Top operators by CUDA time
if self.operator_stats:
report_lines.append("-" * 40)
report_lines.append("TOP 30 OPERATORS BY CUDA TIME")
report_lines.append("-" * 40)
sorted_ops = sorted(
self.operator_stats.items(),
key=lambda x: x[1]['cuda_time_total_us'],
reverse=True
)[:30]
report_lines.append(f"{'Operator':<50} {'Count':>8} {'CUDA(ms)':>12} {'CPU(ms)':>12} {'Self CUDA(ms)':>14}")
report_lines.append("-" * 96)
for op_name, stats in sorted_ops:
# Truncate long operator names
display_name = op_name[:47] + "..." if len(op_name) > 50 else op_name
report_lines.append(
f"{display_name:<50} {stats['count']:>8} "
f"{stats['cuda_time_total_us']/1000:>12.2f} "
f"{stats['cpu_time_total_us']/1000:>12.2f} "
f"{stats['self_cuda_time_total_us']/1000:>14.2f}"
)
report_lines.append("")
# Compute category breakdown
report_lines.append("-" * 40)
report_lines.append("OPERATOR CATEGORY BREAKDOWN")
report_lines.append("-" * 40)
categories = {
'Attention': ['attention', 'softmax', 'bmm', 'baddbmm'],
'Convolution': ['conv', 'cudnn'],
'Normalization': ['norm', 'layer_norm', 'batch_norm', 'group_norm'],
'Activation': ['relu', 'gelu', 'silu', 'sigmoid', 'tanh'],
'Linear/GEMM': ['linear', 'addmm', 'mm', 'gemm'],
'Memory': ['copy', 'contiguous', 'view', 'reshape', 'permute', 'transpose'],
'Elementwise': ['add', 'mul', 'div', 'sub', 'pow', 'exp', 'sqrt'],
}
category_times = {cat: 0.0 for cat in categories}
category_times['Other'] = 0.0
for op_name, stats in self.operator_stats.items():
op_lower = op_name.lower()
categorized = False
for cat, keywords in categories.items():
if any(kw in op_lower for kw in keywords):
category_times[cat] += stats['cuda_time_total_us']
categorized = True
break
if not categorized:
category_times['Other'] += stats['cuda_time_total_us']
total_op_time = sum(category_times.values())
report_lines.append(f"{'Category':<30} {'CUDA Time(ms)':>15} {'Percentage':>12}")
report_lines.append("-" * 57)
for cat, time_us in sorted(category_times.items(), key=lambda x: -x[1]):
pct = (time_us / total_op_time * 100) if total_op_time > 0 else 0
report_lines.append(f"{cat:<30} {time_us/1000:>15.2f} {pct:>11.1f}%")
report_lines.append("")
report = "\n".join(report_lines)
return report
def save_results(self):
"""Save all profiling results to files."""
if not self.enabled:
return
# Save report
report = self.generate_report()
report_path = os.path.join(self.output_dir, "profiling_report.txt")
with open(report_path, 'w') as f:
f.write(report)
print(f">>> Profiling report saved to: {report_path}")
# Save detailed JSON data
data = {
'macro_timings': {
name: {
'times': times,
'cuda_times': [ct for _, ct in self.cuda_events.get(name, [])]
}
for name, times in self.macro_timings.items()
},
'memory_snapshots': self.memory_snapshots,
'operator_stats': self.operator_stats,
}
json_path = os.path.join(self.output_dir, "profiling_data.json")
with open(json_path, 'w') as f:
json.dump(data, f, indent=2)
print(f">>> Detailed profiling data saved to: {json_path}")
# Print summary to console
print("\n" + report)
# added:
atexit.register(_flush_io)


def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None:
    """Synchronous save on CPU tensor (runs in background thread)."""
    video = torch.clamp(video_cpu.float(), -1., 1.)
    n = video.shape[0]
    video = video.permute(2, 0, 1, 3, 4)
    frame_grids = [
        torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
        for framesheet in video
    ]
    grid = torch.stack(frame_grids, dim=0)
    grid = (grid + 1.0) / 2.0
    grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(filename,
                               grid,
                               fps=fps,
                               video_codec='h264',
                               options={'crf': '10'})


def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
    """Submit video saving to background thread pool."""
    video_cpu = video.detach().cpu()
    fut = _get_io_executor().submit(_save_results_sync, video_cpu, filename, fps)
    _io_futures.append(fut)

# removed:
# Global profiler instance
_profiler: Optional[ProfilerManager] = None


def get_profiler() -> ProfilerManager:
    """Get the global profiler instance."""
    global _profiler
    if _profiler is None:
        _profiler = ProfilerManager(enabled=False)
    return _profiler


def init_profiler(enabled: bool, output_dir: str, profile_detail: str) -> ProfilerManager:
    """Initialize the global profiler."""
    global _profiler
    _profiler = ProfilerManager(
        enabled=enabled,
        output_dir=output_dir,
        profile_detail=profile_detail,
    )
    return _profiler
# ========== Original Functions ==========
@@ -458,7 +147,8 @@ def _load_state_dict(model: nn.Module,
def load_model_checkpoint(model: nn.Module,
                          ckpt: str,
                          # removed:
                          # assign: bool | None = None) -> nn.Module:
                          # added:
                          assign: bool | None = None,
                          device: str | torch.device = "cpu") -> nn.Module:
    """Load model weights from checkpoint file.

    Args:
@@ -467,11 +157,12 @@ def load_model_checkpoint(model: nn.Module,
        assign (bool | None): Whether to preserve checkpoint tensor dtypes
            via load_state_dict(assign=True). If None, auto-enable when a
            casted checkpoint metadata is detected.
        device (str | torch.device): Target device for loaded tensors.  # added

    Returns:
        nn.Module: Model with loaded weights.
    """
    # removed:
    # ckpt_data = torch.load(ckpt, map_location="cpu")
    # added:
    ckpt_data = torch.load(ckpt, map_location=device, mmap=True)
    use_assign = False
    if assign is not None:
        use_assign = assign
@@ -511,9 +202,7 @@
def maybe_cast_module(module: nn.Module | None,
                      dtype: torch.dtype,
                      # removed:
                      # label: str,
                      # profiler: Optional[ProfilerManager] = None,
                      # profile_name: Optional[str] = None) -> None:
                      # added:
                      label: str) -> None:
    if module is None:
        return
    try:
@@ -524,10 +213,6 @@ def maybe_cast_module(module: nn.Module | None,
        if param.dtype == dtype:
            print(f">>> {label} already {dtype}; skip cast")
            return
    # removed:
    # ctx = nullcontext()
    # if profiler is not None and profile_name:
    #     ctx = profiler.profile_section(profile_name)
    # with ctx:
    #     module.to(dtype=dtype)
    # added:
    module.to(dtype=dtype)
    print(f">>> {label} cast to {dtype}")
@@ -746,8 +431,6 @@ def get_latent_z(model, videos: Tensor) -> Tensor:
    Returns:
        Tensor: Latent video tensor of shape [B, C, T, H, W].
    """
    # removed:
    # profiler = get_profiler()
    # with profiler.profile_section("get_latent_z/encode"):
    b, c, t, h, w = videos.shape
    x = rearrange(videos, 'b c t h w -> (b t) c h w')
    vae_ctx = nullcontext()
@@ -862,8 +545,6 @@ def image_guided_synthesis_sim_mode(
        actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
        states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
    """
    # removed:
    # profiler = get_profiler()
    b, _, t, _, _ = noise_shape
    ddim_sampler = getattr(model, "_ddim_sampler", None)
    if ddim_sampler is None:
@@ -873,7 +554,6 @@
    fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
    # removed:
    # with profiler.profile_section("synthesis/conditioning_prep"):
    img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
    cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
    if getattr(model, "encoder_bf16", False) and model.device.type == "cuda":
@@ -959,7 +639,6 @@
    cond_z0 = None
    if ddim_sampler is not None:
        # removed:
        # with profiler.profile_section("synthesis/ddim_sampling"):
        autocast_ctx = nullcontext()
        if diffusion_autocast_dtype is not None and model.device.type == "cuda":
            autocast_ctx = torch.autocast("cuda", dtype=diffusion_autocast_dtype)
@@ -982,7 +661,6 @@
                                 **kwargs)
    # Reconstruct from latent to pixel space
    # removed:
    # with profiler.profile_section("synthesis/decode_first_stage"):
    if getattr(model, "vae_bf16", False):
        if samples.dtype != torch.bfloat16:
            samples = samples.to(dtype=torch.bfloat16)
@@ -1012,21 +690,31 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
     Returns:
         None
     """
-    profiler = get_profiler()
-    # Create inference and tensorboard dirs
+    # Create inference dir
     os.makedirs(args.savedir + '/inference', exist_ok=True)
-    log_dir = args.savedir + f"/tensorboard"
-    os.makedirs(log_dir, exist_ok=True)
-    writer = SummaryWriter(log_dir=log_dir)
     # Load prompt
     csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
     df = pd.read_csv(csv_path)
-    # Load config
-    with profiler.profile_section("model_loading/config"):
+    # Load config (always needed for data setup)
     config = OmegaConf.load(args.config)
+    prepared_path = args.ckpt_path + ".prepared.pt"
+    if os.path.exists(prepared_path):
+        # ---- Fast path: load the fully-prepared model ----
+        print(f">>> Loading prepared model from {prepared_path} ...")
+        model = torch.load(prepared_path,
+                           map_location=f"cuda:{gpu_no}",
+                           weights_only=False,
+                           mmap=True)
+        model.eval()
+        diffusion_autocast_dtype = (torch.bfloat16
+                                    if args.diffusion_dtype == "bf16"
+                                    else None)
+        print(f">>> Prepared model loaded.")
+    else:
+        # ---- Normal path: construct + checkpoint + casting ----
         config['model']['params']['wma_config']['params'][
             'use_checkpoint'] = False
         model = instantiate_from_config(config.model)
@@ -1034,30 +722,18 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
         assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
-        with profiler.profile_section("model_loading/checkpoint"):
-            model = load_model_checkpoint(model, args.ckpt_path)
+        model = load_model_checkpoint(model, args.ckpt_path,
+                                      device=f"cuda:{gpu_no}")
         model.eval()
+        model = model.cuda(gpu_no)  # move residual buffers not in state_dict
         print(f'>>> Load pre-trained model ...')
-    # Build unnomalizer
-    logging.info("***** Configing Data *****")
-    with profiler.profile_section("data_loading"):
-        data = instantiate_from_config(config.data)
-        data.setup()
-    print(">>> Dataset is successfully loaded ...")
-    with profiler.profile_section("model_to_cuda"):
-        model = model.cuda(gpu_no)
-    device = get_device_from_parameters(model)
     diffusion_autocast_dtype = None
     if args.diffusion_dtype == "bf16":
         maybe_cast_module(
             model.model,
             torch.bfloat16,
             "diffusion backbone",
-            profiler=profiler,
-            profile_name="model_loading/diffusion_bf16",
         )
         diffusion_autocast_dtype = torch.bfloat16
         print(">>> diffusion backbone set to bfloat16")
@@ -1068,12 +744,29 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
             model.first_stage_model,
             vae_weight_dtype,
             "VAE",
-            profiler=profiler,
-            profile_name="model_loading/vae_cast",
         )
         model.vae_bf16 = args.vae_dtype == "bf16"
         print(f">>> VAE dtype set to {args.vae_dtype}")
+    # --- VAE performance optimizations ---
+    if hasattr(model, "first_stage_model") and model.first_stage_model is not None:
+        vae = model.first_stage_model
+        # torch.compile: fuses GroupNorm+SiLU, conv chains, etc.
+        if args.vae_compile:
+            vae.decoder = torch.compile(vae.decoder, mode="reduce-overhead")
+            vae.encoder = torch.compile(vae.encoder, mode="reduce-overhead")
+            print(">>> VAE encoder/decoder compiled with torch.compile (reduce-overhead)")
+        # Batch decode size
+        vae_decode_bs = args.vae_decode_bs if args.vae_decode_bs > 0 else 9999
+        model.vae_decode_bs = vae_decode_bs
+        model.vae_encode_bs = vae_decode_bs
+        if args.vae_decode_bs > 0:
+            print(f">>> VAE encode/decode batch size set to {args.vae_decode_bs}")
+        else:
+            print(">>> VAE encode/decode batch size: all frames at once")
     encoder_mode = args.encoder_mode
     encoder_bf16 = encoder_mode in ("autocast", "bf16_full")
     encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32
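The `--vae_compile` branch above wraps only the VAE encoder/decoder. A small sketch of the same pattern on a toy decoder (`suppress_errors` is added here, beyond what the hunk shows, so the example degrades to eager execution where no inductor toolchain is available):

```python
import torch
import torch._dynamo
import torch.nn as nn

# Fall back to eager mode if the inductor backend cannot compile in this
# environment (e.g. no C++ compiler); the real script assumes a CUDA box.
torch._dynamo.config.suppress_errors = True

# Toy decoder with the Conv -> GroupNorm -> SiLU -> Conv pattern that
# torch.compile fuses into fewer kernels.
decoder = nn.Sequential(
    nn.Conv2d(4, 32, 3, padding=1),
    nn.GroupNorm(8, 32),
    nn.SiLU(),
    nn.Conv2d(32, 3, 3, padding=1),
).eval()

compiled = torch.compile(decoder, mode="reduce-overhead")

z = torch.randn(2, 4, 16, 16)
with torch.no_grad():
    out = compiled(z)   # first call triggers compilation
    ref = decoder(z)    # eager reference with the same weights
```

`mode="reduce-overhead"` additionally enables CUDA graphs on GPU, which targets exactly the kernel-launch overhead this commit is chasing.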
@@ -1082,16 +775,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
             model.cond_stage_model,
             encoder_weight_dtype,
             "cond_stage_model",
-            profiler=profiler,
-            profile_name="model_loading/encoder_cond_cast",
         )
         if hasattr(model, "embedder") and model.embedder is not None:
             maybe_cast_module(
                 model.embedder,
                 encoder_weight_dtype,
                 "embedder",
-                profiler=profiler,
-                profile_name="model_loading/encoder_embedder_cast",
             )
         model.encoder_bf16 = encoder_bf16
         model.encoder_mode = encoder_mode
@@ -1107,24 +796,18 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
             model.image_proj_model,
             projector_weight_dtype,
             "image_proj_model",
-            profiler=profiler,
-            profile_name="model_loading/projector_image_cast",
         )
     if hasattr(model, "state_projector") and model.state_projector is not None:
         maybe_cast_module(
             model.state_projector,
             projector_weight_dtype,
             "state_projector",
-            profiler=profiler,
-            profile_name="model_loading/projector_state_cast",
         )
     if hasattr(model, "action_projector") and model.action_projector is not None:
         maybe_cast_module(
             model.action_projector,
             projector_weight_dtype,
             "action_projector",
-            profiler=profiler,
-            profile_name="model_loading/projector_action_cast",
         )
     if hasattr(model, "projector_bf16"):
         model.projector_bf16 = projector_bf16
@@ -1148,7 +831,18 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
         print(">>> export_only set; skipping inference.")
         return
-    profiler.record_memory("after_model_load")
+    # Save prepared model for fast loading next time
+    if prepared_path:
+        print(f">>> Saving prepared model to {prepared_path} ...")
+        torch.save(model, prepared_path)
+        print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
+    # Build normalizer (always needed, independent of model loading path)
+    logging.info("***** Configing Data *****")
+    data = instantiate_from_config(config.data)
+    data.setup()
+    print(">>> Dataset is successfully loaded ...")
+    device = get_device_from_parameters(model)
     # Run over data
     assert (args.height % 16 == 0) and (
print(f'>>> Generate {n_frames} frames under each generation ...') print(f'>>> Generate {n_frames} frames under each generation ...')
noise_shape = [args.bs, channels, n_frames, h, w] noise_shape = [args.bs, channels, n_frames, h, w]
# Determine profiler iterations
profile_active_iters = getattr(args, 'profile_iterations', 3)
use_pytorch_profiler = profiler.enabled and profile_active_iters > 0
# Start inference # Start inference
for idx in range(0, len(df)): for idx in range(0, len(df)):
sample = df.iloc[idx] sample = df.iloc[idx]
@@ -1182,7 +872,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
         # Load transitions to get the initial state later
         transition_path = get_transition_path(args.prompt_dir, sample)
-        with profiler.profile_section("load_transitions"):
         with h5py.File(transition_path, 'r') as h5f:
             transition_dict = {}
             for key in h5f.keys():
@@ -1210,7 +899,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
         }
         # Obtain initial frame and state
-        with profiler.profile_section("prepare_init_input"):
         start_idx = 0
         model_input_fs = ori_fps // fs
         batch, ori_state_dim, ori_action_dim = prepare_init_input(
@@ -1233,24 +921,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
         # Update observation queues
         cond_obs_queues = populate_queues(cond_obs_queues, observation)
-        # Setup PyTorch profiler context if enabled
-        pytorch_prof_ctx = nullcontext()
-        if use_pytorch_profiler:
-            pytorch_prof_ctx = profiler.start_pytorch_profiler(
-                wait=1, warmup=1, active=profile_active_iters
-            )
         # Multi-round interaction with the world-model
-        with pytorch_prof_ctx:
         for itr in tqdm(range(args.n_iter)):
             log_every = max(1, args.step_log_every)
             log_step = (itr % log_every == 0)
-            profiler.current_iteration = itr
-            profiler.record_memory(f"iter_{itr}_start")
-            with profiler.profile_section("iteration_total"):
             # Get observation
-            with profiler.profile_section("prepare_observation"):
             observation = {
                 'observation.images.top':
                 torch.stack(list(
@@ -1267,7 +943,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
             # Use world-model in policy to generate action
             if log_step:
                 print(f'>>> Step {itr}: generating actions ...')
-            with profiler.profile_section("action_generation"):
             pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
                 model,
                 sample['instruction'],
@@ -1285,7 +960,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
                 diffusion_autocast_dtype=diffusion_autocast_dtype)
             # Update future actions in the observation queues
-            with profiler.profile_section("update_action_queues"):
             for act_idx in range(len(pred_actions[0])):
                 obs_update = {'action': pred_actions[0][act_idx:act_idx + 1]}
                 obs_update['action'][:, ori_action_dim:] = 0.0
@@ -1293,7 +967,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
                                                   obs_update)
             # Collect data for interacting the world-model using the predicted actions
-            with profiler.profile_section("prepare_wm_observation"):
             observation = {
                 'observation.images.top':
                 torch.stack(list(
@@ -1310,7 +983,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
             # Interaction with the world-model
             if log_step:
                 print(f'>>> Step {itr}: interacting with world model ...')
-            with profiler.profile_section("world_model_interaction"):
             pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
                 model,
                 "",
@@ -1327,7 +999,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
                 guidance_rescale=args.guidance_rescale,
                 diffusion_autocast_dtype=diffusion_autocast_dtype)
-            with profiler.profile_section("update_state_queues"):
             for step_idx in range(args.exe_steps):
                 obs_update = {
                     'observation.images.top':
@@ -1342,28 +1013,14 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
                 cond_obs_queues = populate_queues(cond_obs_queues,
                                                   obs_update)
-            # Save the imagen videos for decision-making
-            with profiler.profile_section("save_results"):
-                sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
-                log_to_tensorboard(writer,
-                                   pred_videos_0,
-                                   sample_tag,
-                                   fps=args.save_fps)
-                # Save videos environment changes via world-model interaction
-                sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
-                log_to_tensorboard(writer,
-                                   pred_videos_1,
-                                   sample_tag,
-                                   fps=args.save_fps)
-                # Save the imagen videos for decision-making
+            # Save the imagen videos for decision-making (async)
             sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
-            save_results(pred_videos_0.cpu(),
-                         sample_video_file,
-                         fps=args.save_fps)
+            save_results_async(pred_videos_0,
+                               sample_video_file,
+                               fps=args.save_fps)
             # Save videos environment changes via world-model interaction
             sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
-            save_results(pred_videos_1.cpu(),
-                         sample_video_file,
-                         fps=args.save_fps)
+            save_results_async(pred_videos_1,
+                               sample_video_file,
+                               fps=args.save_fps)
@@ -1371,20 +1028,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
             # Collect the result of world-model interactions
             wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu())
-            profiler.record_memory(f"iter_{itr}_end")
-            profiler.step_profiler()
         full_video = torch.cat(wm_video, dim=2)
-        sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
-        log_to_tensorboard(writer,
-                           full_video,
-                           sample_tag,
-                           fps=args.save_fps)
         sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
-        save_results(full_video, sample_full_video_file, fps=args.save_fps)
-    # Save profiling results
-    profiler.save_results()
+        save_results_async(full_video, sample_full_video_file, fps=args.save_fps)
+    # Wait for all async I/O to complete
+    _flush_io()
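The bodies of the new `save_results_async` / `_flush_io` helpers are not part of this hunk. A plausible minimal implementation inferred from the call sites (names follow the script; the mp4 encoder is replaced by a stand-in writer). The point is that the ~3.5 s per-save cost visible in the profiler tables moves off the main thread:

```python
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor

import torch

_IO_POOL = ThreadPoolExecutor(max_workers=2)
_PENDING = []


def _write_video(video, path, fps):
    # Stand-in for the real mp4 encoder: the device-to-host copy and the
    # file write both happen inside the worker thread.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save({"video": video.detach().cpu(), "fps": fps}, path)


def save_results_async(video, path, fps=8):
    # Queue the write and return immediately, so the next diffusion
    # iteration can start instead of blocking on video encoding.
    _PENDING.append(_IO_POOL.submit(_write_video, video, path, fps))


def _flush_io():
    # Called once at the end of inference: block until every queued write
    # has finished, re-raising any exception from the worker threads.
    for fut in _PENDING:
        fut.result()
    _PENDING.clear()


out_path = os.path.join(tempfile.mkdtemp(), "dm", "itr-0.pt")
save_results_async(torch.zeros(1, 3, 4, 4), out_path)
_flush_io()
```

This also explains why the call sites dropped the explicit `.cpu()`: the copy belongs in the worker, not on the hot path.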
 def get_parser():
@@ -1511,6 +1160,18 @@ def get_parser():
         default="fp32",
         help="Dtype for VAE/first_stage_model weights and forward autocast."
     )
+    parser.add_argument(
+        "--vae_compile",
+        action='store_true',
+        default=False,
+        help="Apply torch.compile to VAE decoder for kernel fusion."
+    )
+    parser.add_argument(
+        "--vae_decode_bs",
+        type=int,
+        default=0,
+        help="VAE decode batch size (0=all frames at once). Reduces kernel launch overhead."
+    )
     parser.add_argument(
         "--export_casted_ckpt",
         type=str,
@@ -1556,32 +1217,6 @@ def get_parser():
         type=int,
         default=8,
         help="fps for the saving video")
-    # Profiling arguments
-    parser.add_argument(
-        "--profile",
-        action='store_true',
-        default=False,
-        help="Enable performance profiling (macro and operator-level analysis)."
-    )
-    parser.add_argument(
-        "--profile_output_dir",
-        type=str,
-        default=None,
-        help="Directory to save profiling results. Defaults to {savedir}/profile_output."
-    )
-    parser.add_argument(
-        "--profile_iterations",
-        type=int,
-        default=3,
-        help="Number of iterations to run PyTorch profiler's active phase for operator-level analysis."
-    )
-    parser.add_argument(
-        "--profile_detail",
-        type=str,
-        choices=["light", "full"],
-        default="light",
-        help="Profiling detail level. Use 'full' for shapes/stacks/memory/flops."
-    )
     return parser
@@ -1593,15 +1228,5 @@ if __name__ == '__main__':
         seed = random.randint(0, 2**31)
     seed_everything(seed)
-    # Initialize profiler
-    profile_output_dir = args.profile_output_dir
-    if profile_output_dir is None:
-        profile_output_dir = os.path.join(args.savedir, "profile_output")
-    init_profiler(
-        enabled=args.profile,
-        output_dir=profile_output_dir,
-        profile_detail=args.profile_detail,
-    )
     rank, gpu_num = 0, 1
     run_inference(args, gpu_num, rank)

View File

@@ -99,7 +99,6 @@ class AutoencoderKL(pl.LightningModule):
         print(f"Restored from {path}")

     def encode(self, x, **kwargs):
         h = self.encoder(x)
         moments = self.quant_conv(h)
         posterior = DiagonalGaussianDistribution(moments)

View File

@@ -1073,11 +1073,15 @@ class LatentDiffusion(DDPM):
         if not self.perframe_ae:
             encoder_posterior = self.first_stage_model.encode(x)
             results = self.get_first_stage_encoding(encoder_posterior).detach()
-        else:  ## Consume less GPU memory but slower
+        else:  ## Batch encode with configurable batch size
+            bs = getattr(self, 'vae_encode_bs', 1)
+            if bs >= x.shape[0]:
+                encoder_posterior = self.first_stage_model.encode(x)
+                results = self.get_first_stage_encoding(encoder_posterior).detach()
+            else:
                 results = []
-            for index in range(x.shape[0]):
-                frame_batch = self.first_stage_model.encode(x[index:index +
-                                                              1, :, :, :])
+                for i in range(0, x.shape[0], bs):
+                    frame_batch = self.first_stage_model.encode(x[i:i + bs])
                     frame_result = self.get_first_stage_encoding(
                         frame_batch).detach()
                     results.append(frame_result)
@@ -1105,15 +1109,20 @@ class LatentDiffusion(DDPM):
         else:
             reshape_back = False
+        z = 1. / self.scale_factor * z
         if not self.perframe_ae:
-            z = 1. / self.scale_factor * z
             results = self.first_stage_model.decode(z, **kwargs)
         else:
+            bs = getattr(self, 'vae_decode_bs', 1)
+            if bs >= z.shape[0]:
+                # all frames in one batch
+                results = self.first_stage_model.decode(z, **kwargs)
+            else:
                 results = []
-            for index in range(z.shape[0]):
-                frame_z = 1. / self.scale_factor * z[index:index + 1, :, :, :]
-                frame_result = self.first_stage_model.decode(frame_z, **kwargs)
-                results.append(frame_result)
+                for i in range(0, z.shape[0], bs):
+                    results.append(
+                        self.first_stage_model.decode(z[i:i + bs], **kwargs))
                 results = torch.cat(results, dim=0)
         if reshape_back:
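The chunked loop above subsumes both old behaviors: `bs=1` reproduces the per-frame path, and `bs >= N` is the all-at-once path. The pattern in isolation, with a trivial stand-in for `first_stage_model.decode`:

```python
import torch


def chunked_decode(decode, z, bs):
    # bs large enough: one batched call, minimal kernel-launch overhead.
    if bs >= z.shape[0]:
        return decode(z)
    # Otherwise decode in chunks of bs frames: far fewer Python iterations
    # than per-frame decoding, with bounded peak activation memory.
    return torch.cat([decode(z[i:i + bs]) for i in range(0, z.shape[0], bs)],
                     dim=0)


decode = lambda t: t * 2.0 + 1.0   # stand-in for the VAE decoder
z = torch.arange(12.0).view(6, 2)

full = chunked_decode(decode, z, bs=6)    # single batch
chunks = chunked_decode(decode, z, bs=4)  # two chunks: 4 + 2 frames
```

Chunking only changes scheduling, not values, so any `bs` yields the same output; `--vae_decode_bs` is purely a throughput/memory knob.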

View File

@@ -55,16 +55,13 @@ class DDIMSampler(object):
                              to_torch(self.model.alphas_cumprod_prev))
         # Calculations for diffusion q(x_t | x_{t-1}) and others
-        self.register_buffer('sqrt_alphas_cumprod',
-                             to_torch(np.sqrt(alphas_cumprod.cpu())))
-        self.register_buffer('sqrt_one_minus_alphas_cumprod',
-                             to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
-        self.register_buffer('log_one_minus_alphas_cumprod',
-                             to_torch(np.log(1. - alphas_cumprod.cpu())))
-        self.register_buffer('sqrt_recip_alphas_cumprod',
-                             to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
-        self.register_buffer('sqrt_recipm1_alphas_cumprod',
-                             to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+        # Computed directly on GPU to avoid CPU↔GPU transfers
+        ac = to_torch(alphas_cumprod)
+        self.register_buffer('sqrt_alphas_cumprod', ac.sqrt())
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', (1. - ac).sqrt())
+        self.register_buffer('log_one_minus_alphas_cumprod', (1. - ac).log())
+        self.register_buffer('sqrt_recip_alphas_cumprod', ac.rsqrt())
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', (1. / ac - 1).sqrt())

         # DDIM sampling parameters
         ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
@@ -86,6 +83,11 @@ class DDIMSampler(object):
         self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
         self.register_buffer('ddim_sqrt_one_minus_alphas',
                              torch.sqrt(1. - ddim_alphas))
+        # Precomputed coefficients for DDIM update formula
+        self.register_buffer('ddim_sqrt_alphas', ddim_alphas.sqrt())
+        self.register_buffer('ddim_sqrt_alphas_prev', ddim_alphas_prev.sqrt())
+        self.register_buffer('ddim_dir_coeff',
+                             (1. - ddim_alphas_prev - ddim_sigmas**2).sqrt())
         sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
             (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
             (1 - self.alphas_cumprod / self.alphas_cumprod_prev))
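The rewritten buffer setup replaces the old tensor → CPU → NumPy → tensor round trip with pure tensor ops on whatever device the schedule already lives on. A toy schedule showing the two patterns agree numerically:

```python
import numpy as np
import torch

alphas_cumprod = torch.linspace(0.999, 0.01, 50)  # toy noise schedule

# Old pattern: detour through NumPy on the CPU for every buffer.
old_sqrt_recipm1 = torch.from_numpy(
    np.sqrt(1. / alphas_cumprod.cpu().numpy() - 1)).float()

# New pattern: tensor ops, no host<->device transfers when the schedule
# lives on the GPU.
ac = alphas_cumprod
new_sqrt = ac.sqrt()
new_sqrt_one_minus = (1. - ac).sqrt()
new_sqrt_recip = ac.rsqrt()              # fused 1/sqrt kernel
new_sqrt_recipm1 = (1. / ac - 1).sqrt()
```

On a real run the buffers are registered on the sampler once, so the win is mostly at init time plus avoiding CPU-resident buffers being re-uploaded during sampling.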
@@ -208,18 +210,11 @@ class DDIMSampler(object):
         dp_ddim_scheduler_state = self.model.dp_noise_scheduler_state
         b = shape[0]
-        if x_T is None:
-            img = torch.randn(shape, device=device)
-            action = torch.randn((b, 16, self.model.agent_action_dim),
-                                 device=device)
-            state = torch.randn((b, 16, self.model.agent_state_dim),
-                                device=device)
-        else:
-            img = x_T
         action = torch.randn((b, 16, self.model.agent_action_dim),
                              device=device)
         state = torch.randn((b, 16, self.model.agent_state_dim),
                             device=device)
+        img = torch.randn(shape, device=device) if x_T is None else x_T

         if precision is not None:
             if precision == 16:
@@ -362,12 +357,13 @@ class DDIMSampler(object):
                                        **kwargs)
             else:
                 raise NotImplementedError
-            model_output = e_t_uncond + unconditional_guidance_scale * (
-                e_t_cond - e_t_uncond)
-            model_output_action = e_t_uncond_action + unconditional_guidance_scale * (
-                e_t_cond_action - e_t_uncond_action)
-            model_output_state = e_t_uncond_state + unconditional_guidance_scale * (
-                e_t_cond_state - e_t_uncond_state)
+            model_output = torch.lerp(e_t_uncond, e_t_cond,
+                                      unconditional_guidance_scale)
+            model_output_action = torch.lerp(e_t_uncond_action,
+                                             e_t_cond_action,
+                                             unconditional_guidance_scale)
+            model_output_state = torch.lerp(e_t_uncond_state, e_t_cond_state,
+                                            unconditional_guidance_scale)

         if guidance_rescale > 0.0:
             model_output = rescale_noise_cfg(
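`torch.lerp(start, end, weight)` computes `start + weight * (end - start)` in a single fused kernel and extrapolates for `weight > 1`, so it is exactly the classifier-free-guidance combination above. A quick equivalence check with a typical guidance scale:

```python
import torch

torch.manual_seed(0)
e_t_uncond = torch.randn(2, 4, 8, 8)   # unconditional noise prediction
e_t_cond = torch.randn(2, 4, 8, 8)     # conditional noise prediction
scale = 7.5                            # CFG scale > 1 -> extrapolation

# Fused form used by the sampler after this change.
fused = torch.lerp(e_t_uncond, e_t_cond, scale)

# Original three-kernel form (sub, mul, add) with a temporary tensor.
reference = e_t_uncond + scale * (e_t_cond - e_t_uncond)
```

The gain per call is small, but it runs three times (video, action, state) in every one of the DDIM steps of both synthesis calls per iteration.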
@@ -396,18 +392,28 @@ class DDIMSampler(object):
         sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
         sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+        if use_original_steps:
+            sqrt_alphas = alphas.sqrt()
+            sqrt_alphas_prev = alphas_prev.sqrt()
+            dir_coeffs = (1. - alphas_prev - sigmas**2).sqrt()
+        else:
+            sqrt_alphas = self.ddim_sqrt_alphas
+            sqrt_alphas_prev = self.ddim_sqrt_alphas_prev
+            dir_coeffs = self.ddim_dir_coeff

         if is_video:
             size = (1, 1, 1, 1, 1)
         else:
             size = (1, 1, 1, 1)
-        a_t = alphas[index].view(size)
-        a_prev = alphas_prev[index].view(size)
+        sqrt_at = sqrt_alphas[index].view(size)
+        sqrt_a_prev = sqrt_alphas_prev[index].view(size)
         sigma_t = sigmas[index].view(size)
         sqrt_one_minus_at = sqrt_one_minus_alphas[index].view(size)
+        dir_coeff = dir_coeffs[index].view(size)

         if self.model.parameterization != "v":
-            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+            pred_x0 = (x - sqrt_one_minus_at * e_t) / sqrt_at
         else:
             pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
@@ -420,14 +426,11 @@ class DDIMSampler(object):
         if quantize_denoised:
             pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
-        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
         noise = sigma_t * noise_like(x.shape, device,
                                      repeat_noise) * temperature
         if noise_dropout > 0.:
             noise = torch.nn.functional.dropout(noise, p=noise_dropout)
-        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+        x_prev = sqrt_a_prev * pred_x0 + dir_coeff * e_t + noise
         return x_prev, pred_x0, model_output_action, model_output_state
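Numerically nothing changes: the `.sqrt()` calls just move from inside every denoising step to the sampler's init. A single DDIM update (eps-parameterization, eta = 0 so `sigma_t = 0`) written both ways:

```python
import torch

torch.manual_seed(0)
a_t = torch.tensor(0.8)        # alpha_cumprod at the current step
a_prev = torch.tensor(0.9)     # alpha_cumprod at the previous step
sigma_t = torch.tensor(0.0)    # eta = 0 -> deterministic DDIM
x = torch.randn(1, 4, 2, 2)
e_t = torch.randn(1, 4, 2, 2)

# Precomputed once at sampler init (ddim_sqrt_alphas / ddim_dir_coeff).
sqrt_at = a_t.sqrt()
sqrt_a_prev = a_prev.sqrt()
dir_coeff = (1. - a_prev - sigma_t**2).sqrt()

pred_x0 = (x - (1. - a_t).sqrt() * e_t) / sqrt_at
x_prev_new = sqrt_a_prev * pred_x0 + dir_coeff * e_t

# Original per-step form: .sqrt() evaluated inside every step.
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
x_prev_old = a_prev.sqrt() * pred_x0 + dir_xt
```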
@@ -475,7 +478,7 @@ class DDIMSampler(object):
             sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
             sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
         else:
-            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+            sqrt_alphas_cumprod = self.ddim_sqrt_alphas
             sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
         if noise is None:

View File

@@ -10,8 +10,8 @@ from unifolm_wma.utils.utils import instantiate_from_config
 def nonlinearity(x):
-    # swish
-    return x * torch.sigmoid(x)
+    # swish / SiLU — single fused CUDA kernel instead of x * sigmoid(x)
+    return torch.nn.functional.silu(x)


 def Normalize(in_channels, num_groups=32):
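This swap is a numerical no-op: SiLU is defined as `x * sigmoid(x)`, but `F.silu` dispatches one fused kernel instead of a sigmoid kernel, a multiply kernel, and an intermediate tensor the size of the activation. Quick check:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 32, 8, 8)

fused = F.silu(x)               # single fused kernel
manual = x * torch.sigmoid(x)   # two kernels + one temporary tensor
```

Since `nonlinearity` runs after every GroupNorm in the VAE, this trims a large number of small kernel launches per decode.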

View File

@@ -7,79 +7,97 @@ MACRO-LEVEL TIMING SUMMARY

======= BEFORE (previous run) =======
----------------------------------------
Section                          Count    Total(ms)     Avg(ms)  CUDA Avg(ms)
--------------------------------------------------------------------------------------
action_generation                   11    399707.47    36337.04      36336.85
data_loading                         1        52.85       52.85         52.88
get_latent_z/encode                 22       901.39       40.97         41.01
iteration_total                     11    836793.23    76072.11      76071.63
load_transitions                     1         2.24        2.24          2.28
model_loading/checkpoint             1     11833.31    11833.31      11833.43
model_loading/config                 1     49774.19    49774.19      49774.16
model_to_cuda                        1      8909.30     8909.30       8909.33
prepare_init_input                   1        10.52       10.52         10.55
prepare_observation                 11         5.41        0.49          0.53
prepare_wm_observation              11         2.12        0.19          0.22
save_results                        11     38668.06     3515.28       3515.32
synthesis/conditioning_prep         22      2916.63      132.57        132.61
synthesis/ddim_sampling             22    782695.01    35577.05      35576.86
synthesis/decode_first_stage        22     12444.31      565.65        565.70
update_action_queues                11         6.85        0.62          0.65
update_state_queues                 11        17.67        1.61          1.64
world_model_interaction             11    398375.58    36215.96      36215.75
--------------------------------------------------------------------------------------
TOTAL                                    2543116.13
----------------------------------------
GPU MEMORY SUMMARY
----------------------------------------
Peak allocated:    17890.50 MB
Average allocated: 16129.98 MB

======= AFTER (this change) =======
----------------------------------------
Section                          Count    Total(ms)     Avg(ms)  CUDA Avg(ms)
--------------------------------------------------------------------------------------
action_generation                   11    173133.54    15739.41      15739.36
data_loading                         1        54.31       54.31         54.34
get_latent_z/encode                 22       785.25       35.69         35.72
iteration_total                     11    386482.08    35134.73      35134.55
load_transitions                     1         2.07        2.07          2.10
model_loading/prepared               1      4749.22     4749.22       4749.83
prepare_init_input                   1        29.19       29.19         29.22
prepare_observation                 11         5.49        0.50          0.53
prepare_wm_observation              11         1.93        0.18          0.20
save_results                        11     38791.18     3526.47       3526.51
synthesis/conditioning_prep         22      2528.23      114.92        114.95
synthesis/ddim_sampling             22    336003.29    15272.88      15272.83
synthesis/decode_first_stage        22      9095.14      413.42        413.46
update_action_queues                11         7.28        0.66          0.69
update_state_queues                 11        17.38        1.58          1.61
world_model_interaction             11    174516.52    15865.14      15865.07
--------------------------------------------------------------------------------------
TOTAL                                    1126202.08
----------------------------------------
GPU MEMORY SUMMARY
----------------------------------------
Peak allocated:    18188.29 MB
Average allocated:  9117.49 MB

----------------------------------------
TOP 30 OPERATORS BY CUDA TIME
----------------------------------------

======= BEFORE (previous run) =======
Operator                                            Count    CUDA(ms)     CPU(ms)  Self CUDA(ms)
------------------------------------------------------------------------------------------------
ProfilerStep*                                           6   443804.16   237696.98      237689.25
aten::linear                                       171276   112286.23    13179.82           0.00
aten::addmm                                         81456    79537.36     3799.84       79296.37
ampere_sgemm_128x64_tn                              26400    52052.10        0.00       52052.10
aten::matmul                                        90468    34234.05     6281.32           0.00
aten::_convolution                                 100242    33623.79    13105.89           0.00
aten::mm                                            89820    33580.74     3202.22       33253.18
aten::convolution                                  100242    33575.23    13714.47           0.00
aten::cudnn_convolution                             98430    30932.19     8640.50       29248.12
ampere_sgemm_32x128_tn                              42348    20394.52        0.00       20394.52
aten::conv2d                                        42042    18115.35     5932.30           0.00
ampere_sgemm_128x32_tn                              40938    16429.81        0.00       16429.81
xformers::efficient_attention_forward_cutlass       24000    15222.23     2532.93       15120.44
fmha_cutlassF_f32_aligned_64x64_rf_sm80(Attenti...  24000    15121.31        0.00       15121.31
ampere_sgemm_64x64_tn                               21000    14627.12        0.00       14627.12
aten::copy_                                        231819    14504.87   127056.51       14038.39
aten::group_norm                                    87144    12033.73    10659.57           0.00
aten::native_group_norm                             87144    11473.40     9449.36       11002.02
aten::conv3d                                        26400     8852.13     3365.43           0.00
void at::native::(anonymous namespace)::Rowwise...  87144     8714.68        0.00        8714.68
void cudnn::ops::nchwToNhwcKernel<float, float,... 169824     8525.44        0.00        8525.44
aten::clone                                        214314     8200.26     8568.82           0.00
void at::native::elementwise_kernel<128, 2, at:... 220440     8109.62        0.00        8109.62
void cutlass::Kernel<cutlass_80_simt_sgemm_128x...  15000     7919.30        0.00        7919.30
aten::_to_copy                                      12219     5963.43   122411.53           0.00
aten::to                                            58101     5952.65   122443.72           0.00
aten::conv1d                                        30000     5878.95     4556.48           0.00
Memcpy HtoD (Pageable -> Device)                     6696     5856.39        0.00        5856.39
aten::reshape                                      671772     5124.03     9636.01           0.00
sm80_xmma_fprop_implicit_gemm_indexed_tf32f32_t...  16272     5097.70        0.00        5097.70

======= AFTER (this change) =======
Operator                                            Count    CUDA(ms)     CPU(ms)  Self CUDA(ms)
------------------------------------------------------------------------------------------------
ProfilerStep*                                          18   690146.23   133688.74      616385.44
aten::group_norm                                   168624    24697.84    29217.27           0.00
aten::_convolution                                  96450    21420.26    12845.86           0.00
aten::convolution                                   96450    21408.68    13480.97           0.00
aten::linear                                       297398    20780.15    26257.38           0.00
aten::cudnn_convolution                             94638    18660.24     8239.04       18329.28
aten::copy_                                        772677    18135.46    17387.09       17864.87
aten::conv3d                                        52800    12922.42     8572.58           0.00
aten::conv2d                                        52469    12747.13     7725.70           0.00
aten::native_group_norm                             84312    10285.37     8974.31       10197.66
aten::_to_copy                                     590277    10270.09    22570.90           0.00
aten::to                                           602979     9655.26    23666.06           0.00
aten::conv1d                                        56245     8174.37    10015.24           0.00
void at::native::(anonymous namespace)::Rowwise...  84312     7979.71        0.00        7979.71
aten::clone                                        177132     7502.90     7007.48           0.00
void cudnn::ops::nchwToNhwcKernel<__nv_bfloat16... 164700     7384.52        0.00        7384.52
aten::addmm                                         81456     6958.44     3903.01        6908.44
aten::layer_norm                                    65700     5698.92     7816.08           0.00
void at::native::elementwise_kernel<128, 4, at:... 149688     5372.46        0.00        5372.46
void at::native::unrolled_elementwise_kernel<at... 180120     5165.28        0.00        5165.28
ampere_bf16_s16816gemm_bf16_128x128_ldg8_relu_f...  24900     4449.05        0.00        4449.05
void at::native::unrolled_elementwise_kernel<at... 368664     4405.30        0.00        4405.30
aten::reshape                                      686778     3771.84     8309.51           0.00
aten::contiguous                                    46008     3400.88     1881.73           0.00
sm80_xmma_fprop_implicit_gemm_bf16bf16_bf16f32_...  15516     3398.03        0.00        3398.03
aten::matmul                                        90489     3366.62     4946.69           0.00
aten::mm                                            89820     3284.53     3308.76        3228.56
void at::native::elementwise_kernel<128, 2, at:...  46518     2441.55        0.00        2441.55
aten::add                                          113118     2426.66     2776.23        2385.52
void at::native::elementwise_kernel<128, 4, at:... 104550     2426.41        0.00        2426.41

----------------------------------------
OPERATOR CATEGORY BREAKDOWN
----------------------------------------
Category                     CUDA Time(ms)    Percentage
---------------------------------------------------------
Other 481950.47 41.9% Other 723472.91 71.9%
Linear/GEMM 342333.09 29.8% Convolution 114469.81 11.4%
Convolution 159920.77 13.9% Memory 53845.46 5.4%
Elementwise 54682.93 4.8% Normalization 46852.57 4.7%
Memory 36883.36 3.2% Linear/GEMM 35354.58 3.5%
Attention 34736.13 3.0% Elementwise 17078.44 1.7%
Normalization 32081.19 2.8% Activation 12296.29 1.2%
Activation 6449.19 0.6% Attention 2956.61 0.3%
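The category breakdown is an aggregation keyed on operator names. A hypothetical pure-Python sketch of such a bucketing (the repo's actual aggregation script is not shown here; the pattern lists and function names are illustrative):

```python
# Illustrative category patterns; first match wins, unmatched ops fall into "Other".
CATEGORY_PATTERNS = [
    ("Attention",     ("attention", "fmha", "flash")),
    ("Linear/GEMM",   ("linear", "addmm", "matmul", "::mm", "gemm")),
    ("Convolution",   ("conv", "nchwToNhwc")),
    ("Normalization", ("norm",)),
    ("Elementwise",   ("elementwise", "::add", "mul")),
    ("Memory",        ("copy", "memcpy", "to_copy", "clone", "contiguous")),
    ("Activation",    ("silu", "gelu", "relu", "sigmoid")),
]

def categorize(op_name: str) -> str:
    """Return the first category whose patterns match the operator name."""
    name = op_name.lower()
    for category, patterns in CATEGORY_PATTERNS:
        if any(p.lower() in name for p in patterns):
            return category
    return "Other"

def breakdown(rows):
    """rows: iterable of (op_name, self_cuda_ms) -> {category: total_ms}."""
    totals = {}
    for op, ms in rows:
        cat = categorize(op)
        totals[cat] = totals.get(cat, 0.0) + ms
    return totals
```

Pattern order matters: `aten::addmm` must be claimed by Linear/GEMM before any broader pattern, and kernel names like `ampere_bf16_s16816gemm_...` are caught by the `gemm` substring.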
```
------------------------------------------------------------------------------------------
aten::addmm (Linear/GEMM) UTILIZATION ANALYSIS ON A100
------------------------------------------------------------------------------------------
Effective compute precision: BF16 Tensor Core (312 TFLOPS)
torch.backends.cuda.matmul.allow_tf32 = False

Metric                               Value
-------------------------------------------------------------------
Total aten::addmm calls              81,456
Total Self CUDA time                 6908.44 ms
Total FLOPs (profiler)               1.33 PFLOPs
Achieved throughput                  191.88 TFLOP/s
A100 peak throughput                 312.00 TFLOP/s
MFU (Model FLOPs Utilization)        61.50%

INTERPRETATION:
-------------------------------------------------------------------
Good utilization (>60%). GEMM kernels are compute-bound
and running efficiently on Tensor Cores.
```
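The MFU figure above is just the ratio of achieved to peak throughput. A minimal recomputation from the table's own numbers (using the rounded 1.33 PFLOPs, which reproduces the reported ~61.5% to within rounding):

```python
def mfu(total_flops: float, self_cuda_ms: float, peak_tflops: float = 312.0):
    """Model FLOPs Utilization: achieved TFLOP/s divided by the device peak.

    total_flops  - FLOPs counted by the profiler (with_flops=True)
    self_cuda_ms - total self CUDA time of the kernels, in milliseconds
    peak_tflops  - A100 BF16 Tensor Core peak, 312 TFLOP/s
    """
    achieved_tflops = total_flops / (self_cuda_ms / 1e3) / 1e12
    return achieved_tflops, achieved_tflops / peak_tflops

# Numbers from the aten::addmm table above (1.33 PFLOPs = 1.33e15 FLOPs):
tflops, util = mfu(total_flops=1.33e15, self_cuda_ms=6908.44)
```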


```diff
@@ -2,7 +2,7 @@ res_dir="unitree_g1_pack_camera/case1"
 dataset="unitree_g1_pack_camera"
 {
-time CUDA_VISIBLE_DEVICES=1 python3 scripts/evaluation/world_model_interaction.py \
+time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
     --seed 123 \
     --ckpt_path ckpts/unifolm_wma_dual_mixbf16.ckpt \
     --config configs/inference/world_model_interaction.yaml \
@@ -23,5 +23,6 @@ dataset="unitree_g1_pack_camera"
     --perframe_ae \
     --diffusion_dtype bf16 \
     --projector_mode bf16_full \
-    --encoder_mode bf16_full
+    --encoder_mode bf16_full \
+    --vae_dtype bf16
 } 2>&1 | tee "${res_dir}/output.log"
```
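The new `--vae_dtype bf16` flag follows the same pattern as the existing `--diffusion_dtype` flag. A hypothetical sketch of how such flags could be parsed and mapped to torch dtypes (`build_parser` and `resolve_dtype` are illustrative names; the real parser in `scripts/evaluation/world_model_interaction.py` is not shown here and may differ):

```python
import argparse

DTYPE_CHOICES = ("fp32", "fp16", "bf16")

def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument("--diffusion_dtype", choices=DTYPE_CHOICES, default="fp32")
    parser.add_argument("--vae_dtype", choices=DTYPE_CHOICES, default="fp32",
                        help="run VAE encode/decode in this precision")
    return parser

def resolve_dtype(name: str):
    # torch imported lazily so the parser stays importable without it.
    import torch
    return {"fp32": torch.float32,
            "fp16": torch.float16,
            "bf16": torch.bfloat16}[name]

# Mirrors the flags passed in the script above:
args = build_parser().parse_args(["--diffusion_dtype", "bf16", "--vae_dtype", "bf16"])
```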