Compare commits: 10 commits by qhy
(7e501b17fd, bb274870c2, f1f92072e6, ff920b85a2, 6630952d2b, bc78815acf, d5f6577fa8, 7dcf9e8b89, aba2a90045, 25de36b9bc, 2fdcec6da0)

README.md (245 changed lines)
@@ -1,228 +1,29 @@
# UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family

<p style="font-size: 1.2em;">
  <a href="https://unigen-x.github.io/unifolm-world-model-action.github.io"><strong>Project Page</strong></a> |
  <a href="https://huggingface.co/collections/unitreerobotics/unifolm-wma-0-68ca23027310c0ca0f34959c"><strong>Models</strong></a> |
  <a href="https://huggingface.co/unitreerobotics/datasets"><strong>Dataset</strong></a>
</p>

<div align="center">
  <p align="right">
    <span> 🌎English </span> | <a href="README_cn.md"> 🇨🇳中文 </a>
  </p>
</div>

<div align="justify">
<b>UnifoLM-WMA-0</b> is Unitree's open-source world-model-action architecture spanning multiple types of robotic embodiment, designed for general-purpose robot learning. Its core component is a world model capable of understanding the physical interactions between robots and their environments. The world model provides two key functions: (a) <b>Simulation Engine</b>: operates as an interactive simulator to generate synthetic data for robot learning; (b) <b>Policy Enhancement</b>: connects with an action head and, by predicting future interaction processes with the world model, further optimizes decision-making performance.
</div>

# World Model Interaction Mixed-Precision Acceleration Notes (Case 1)
## 🦾 Real-Robot Demonstrations

| <img src="assets/gifs/real_z1_stackbox.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> | <img src="assets/gifs/real_dual_stackbox.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> |
|:---:|:---:|
| <img src="assets/gifs/real_cleanup_pencils.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> | <img src="assets/gifs/real_g1_pack_camera.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> |

**Note: the top-right window shows the world model's prediction of future action videos.**

## Modified Location

- Script path: `/home/dyz/unifolm-world-model-action/unitree_g1_pack_camera/case1/run_world_model_interaction.sh`
- Current status: some parameters that normally should not be modified (or that require caution) have been changed; they will be fixed as defaults once the optimal settings are confirmed.
## New Parameters (to become defaults once the optimum is confirmed)

- `--diffusion_dtype {fp32,bf16}`: dtype for the diffusion weights and forward pass; default `fp32`
- `--projector_mode {fp32,autocast,bf16_full}`: projector precision strategy; default `fp32`
- `--encoder_mode {fp32,autocast,bf16_full}`: encoder precision strategy; default `fp32`
- `--vae_dtype {fp32,bf16}`: dtype for the VAE weights and forward pass; default `fp32`
- `--export_casted_ckpt <path>`: export a checkpoint under the current precision settings (for offline export of mixed-precision weights)
- `--export_only`: export the checkpoint and exit; off by default

### Parameter Semantics

- `fp32`: both the weights and the forward pass use fp32
- `autocast`: the weights stay in fp32; the forward pass runs under `torch.autocast` (operator-level mixed precision)
- `bf16_full`: the weights are explicitly converted to bf16, and the forward pass also runs mainly in bf16

## Current Best Configuration and Results

### Configuration

- All modules run in bf16 except the VAE
- The mixed-precision checkpoint is exported offline (via `--export_casted_ckpt`)

### Results

- Runtime: reduced from `15m6s` to `7m5s`
- PSNR: drops by less than `4` (`35 -> 31`)
- GPU memory: usage drops to about `50%` of the original

## 🔥 News

* Sep 22, 2025: 🚀 We released the deployment code for assisting experiments with [Unitree](https://www.unitree.com/) robots.
* Sep 15, 2025: 🚀 We released the training and inference code along with the model weights of [**UnifoLM-WMA-0**](https://huggingface.co/collections/unitreerobotics/unifolm-wma-0-68ca23027310c0ca0f34959c).

## 📑 Opensource Plan

- [x] Training
- [x] Inference
- [x] Checkpoints
- [x] Deployment
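The three precision strategies described above (`fp32`, `autocast`, `bf16_full`) can be sketched in PyTorch as follows. This is a minimal illustration; `apply_precision` and its return convention are mine, not part of the script:

```python
import torch
from torch import nn

def apply_precision(module, mode, device_type="cuda"):
    """Map one of the three strategies onto a module; returns the (possibly
    cast) module and a context manager to wrap its forward pass in.
    Hypothetical helper for illustration only."""
    if mode == "fp32":
        # Weights and forward both stay in fp32.
        return module.float(), torch.autocast(device_type, enabled=False)
    if mode == "autocast":
        # Weights stay fp32; individual ops run in bf16 under autocast.
        return module.float(), torch.autocast(device_type, dtype=torch.bfloat16)
    if mode == "bf16_full":
        # Weights are cast to bf16; the forward runs natively in bf16.
        return module.to(torch.bfloat16), torch.autocast(device_type, enabled=False)
    raise ValueError(f"unknown precision mode: {mode}")
```

`autocast` keeps master weights in fp32 (safer, slightly slower), while `bf16_full` halves weight memory, which matches the roughly 50% memory reduction reported below.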
## ⚙️ Installation

```shell
conda create -n unifolm-wma python==3.10.18
conda activate unifolm-wma

conda install pinocchio=3.2.0 -c conda-forge -y
conda install ffmpeg=7.1.1 -c conda-forge

git clone --recurse-submodules https://github.com/unitreerobotics/unifolm-world-model-action.git

# If you already downloaded the repo:
cd unifolm-world-model-action
git submodule update --init --recursive

pip install -e .

cd external/dlimp
pip install -e .
```
## 🧰 Model Checkpoints

| Model | Description | Link |
|---------|-------|------|
| $\text{UnifoLM-WMA-0}_{Base}$ | Fine-tuned on the [Open-X](https://robotics-transformer-x.github.io/) dataset. | [HuggingFace](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0-Base) |
| $\text{UnifoLM-WMA-0}_{Dual}$ | Fine-tuned on five [Unitree open-source datasets](https://huggingface.co/collections/unitreerobotics/g1-dex1-datasets-68bae98bf0a26d617f9983ab) in both decision-making and simulation modes. | [HuggingFace](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0-Dual) |

## 🛢️ Dataset

In our experiments, we consider the following open-source datasets:

| Dataset | Robot | Link |
|---------|-------|------|
| Z1_StackBox | [Unitree Z1](https://www.unitree.com/z1) | [HuggingFace](https://huggingface.co/datasets/unitreerobotics/Z1_StackBox_Dataset/tree/v2.1) |
| Z1_DualArm_StackBox | [Unitree Z1](https://www.unitree.com/z1) | [HuggingFace](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset/tree/v2.1) |
| Z1_DualArm_StackBox_V2 | [Unitree Z1](https://www.unitree.com/z1) | [HuggingFace](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset_V2/tree/v2.1) |
| Z1_DualArm_Cleanup_Pencils | [Unitree Z1](https://www.unitree.com/z1) | [HuggingFace](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_CleanupPencils_Dataset/tree/v2.1) |
| G1_Pack_Camera | [Unitree G1](https://www.unitree.com/g1) | [HuggingFace](https://huggingface.co/datasets/unitreerobotics/G1_Dex1_MountCameraRedGripper_Dataset/tree/v2.1) |
To train on your own dataset, first make sure the data follows the [Huggingface LeRobot V2.1](https://github.com/huggingface/lerobot) dataset format. Assume the dataset's source directory structure is as follows:

```
source_dir/
├── dataset1_name
├── dataset2_name
├── dataset3_name
└── ...
```

Then, convert a dataset to the required format using the command below:

```shell
cd prepare_data
python prepare_training_data.py \
    --source_dir /path/to/your/source_dir \
    --target_dir /path/to/save/the/converted/data \
    --dataset_name "dataset1_name" \
    --robot_name "a tag of the robot in the dataset"  # e.g., Unitree Z1 Robot Arm or Unitree G1 Robot with Gripper
```

The resulting data structure is shown below. (Note: model training only supports input from the main-view camera. If the dataset includes multiple views, remove the corresponding values from the ```data_dir``` column in the CSV file.)
```
target_dir/
├── videos
│   ├── dataset1_name
│   │   ├── camera_view_dir
│   │   ├── 0.mp4
│   │   ├── 1.mp4
│   │   └── ...
│   └── ...
├── transitions
│   ├── dataset1_name
│   │   ├── meta_data
│   │   ├── 0.h5
│   │   ├── 1.h5
│   │   └── ...
└── dataset1_name.csv
```
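For the multi-view note above, the per-dataset CSV can be filtered with a small pandas sketch. The `data_dir` column name comes from the note; the main-view tag is an assumption about your directory naming:

```python
import pandas as pd

def keep_main_view(csv_path, main_view_tag):
    """Keep only rows whose data_dir points at the main-view camera.

    Hypothetical helper for illustration; `main_view_tag` depends on how
    your camera-view directories are named."""
    df = pd.read_csv(csv_path)
    return df[df["data_dir"].str.contains(main_view_tag, na=False)]
```

Write the filtered frame back with `df.to_csv(csv_path, index=False)` once you have checked the remaining rows.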
## 🚴♂️ Training

A. Our training strategy is outlined as follows:

- **Step 1**: Fine-tune a video generation model as the world model using the [Open-X](https://robotics-transformer-x.github.io/) dataset;
- **Step 2**: Post-train $\text{UnifoLM-WMA}$ in decision-making mode on the downstream task dataset;

<div align="left">
  <img src="assets/pngs/dm_mode.png" width="600">
</div>

- **Step 3**: Post-train $\text{UnifoLM-WMA}$ in simulation mode on the downstream task dataset.

<div align="left">
  <img src="assets/pngs/sim_mode.png" width="600">
</div>

**Note**: If you only require $\text{UnifoLM-WMA}$ to operate in a single mode, you may skip the corresponding step.
B. To conduct training on one or more datasets, follow the steps below:

- **Step 1**: The maximum DoF is assumed to be 16. If your robot has more than 16 DoF, update ```agent_state_dim``` and ```agent_action_dim``` in [configs/train/config.yaml](https://github.com/unitreerobotics/unifolm-wma/blob/working/configs/train/config.yaml);
- **Step 2**: Set up the input shapes for each modality in [configs/train/meta.json](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/meta.json);
- **Step 3**: Configure the training parameters in [configs/train/config.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/config.yaml). For ```pretrained_checkpoint```, we recommend using the checkpoint " $\text{UnifoLM-WMA-0}_{Base}$ " fine-tuned on the [Open-X](https://robotics-transformer-x.github.io/) dataset;

```yaml
model:
  pretrained_checkpoint: /path/to/pretrained/checkpoint
  ...
  decision_making_only: True  # Train the world model only in decision-making mode. If False, jointly train it in both decision-making and simulation modes.
  ...
data:
  ...
  train:
    ...
    data_dir: /path/to/training/dataset/directory
    dataset_and_weights:  # List the name of each dataset below; the weights must sum to 1.0.
      dataset1_name: 0.2
      dataset2_name: 0.2
      dataset3_name: 0.2
      dataset4_name: 0.2
      dataset5_name: 0.2
```

- **Step 4**: Set the ```experiment_name``` and ```save_root``` variables in [scripts/train.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/train.sh);
- **Step 5**: Launch the training with the command:

```shell
bash scripts/train.sh
```
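The weight constraint on `dataset_and_weights` can be sketched as a small check (a hypothetical helper mirroring the comment in `config.yaml`; the training code performs its own validation):

```python
def check_dataset_weights(dataset_and_weights):
    """Raise if the per-dataset sampling weights do not sum to 1.0."""
    total = sum(dataset_and_weights.values())
    if abs(total - 1.0) > 1e-6:
        raise ValueError(f"dataset weights sum to {total}, expected 1.0")

# Five datasets weighted 0.2 each, as in the YAML above.
check_dataset_weights({f"dataset{i}_name": 0.2 for i in range(1, 6)})
```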
## 🌏 Inference under Interactive Simulation Mode

To run the world model in interactive simulation mode, follow these steps:

- **Step 1**: (Skip this step if you only want to test with the examples we provide.) Prepare your own prompts following the format used in [examples/world_model_interaction_prompts](https://github.com/unitreerobotics/unitree-world-model/tree/main/examples/world_model_interaction_prompts):

```
world_model_interaction_prompts/
├── images
│   ├── dataset1_name
│   │   ├── 0.png            # Image prompt
│   │   └── ...
│   └── ...
├── transitions
│   ├── dataset1_name
│   │   ├── meta_data        # Used for normalization
│   │   ├── 0.h5             # Robot state and action data; in interaction mode,
│   │   │                    # only used to retrieve the robot state corresponding
│   │   │                    # to the image prompt
│   │   └── ...
│   └── ...
├── dataset1_name.csv        # File for loading image prompts, text instructions and corresponding robot states
└── ...
```

- **Step 2**: Specify the correct paths for ```pretrained_checkpoint``` (e.g., $\text{UnifoLM-WMA-0}_{Dual}$) and ```data_dir``` in [configs/inference/world_model_interaction.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/inference/world_model_interaction.yaml);
- **Step 3**: Set the paths for ```checkpoint```, ```res_dir``` and ```prompt_dir``` in [scripts/run_world_model_interaction.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/run_world_model_interaction.sh), and specify every dataset name in ```datasets=(...)```. Then, launch the inference with the command:

```shell
bash scripts/run_world_model_interaction.sh
```
## 🧠 Inference and Deployment under Decision-Making Mode

In this setup, inference is performed on a server, while a robot client gathers observations from the real robot and sends them to the server to query actions. The process unfolds through the following steps:

### Server Setup
- **Step-1**: Specify ```ckpt```, ```res_dir``` and ```datasets``` in [scripts/run_real_eval_server.sh](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/scripts/run_real_eval_server.sh);
- **Step-2**: Configure ```data_dir``` and ```dataset_and_weights``` in [config/inference/world_model_decision_making.yaml](https://github.com/unitreerobotics/unifolm-world-model-action/blob/f12b4782652ca00452941d851b17446e4ee7124a/configs/inference/world_model_decision_making.yaml#L225);
- **Step-3**: Launch the server:

```shell
conda activate unifolm-wma
cd unifolm-world-model-action
bash scripts/run_real_eval_server.sh
```

### Client Setup
- **Step-1**: Follow the instructions in [unitree_deploy/README.md](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/unitree_deploy/README.md) to create the ```unitree_deploy``` conda environment, install the required packages, and launch the controllers or services on the real robot.
- **Step-2**: Open a new terminal and establish a tunnel connection from the client to the server:

```shell
ssh user_name@remote_server_IP -CNg -L 8000:127.0.0.1:8000
```

- **Step-3**: Run the ```unitree_deploy/robot_client.py``` script to start inference:

```shell
cd unitree_deploy
python scripts/robot_client.py --robot_type "g1_dex1" --action_horizon 16 --exe_steps 16 --observation_horizon 2 --language_instruction "pack black camera into box" --output_dir ./results --control_freq 15
```
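The client flags above (`--action_horizon`, `--exe_steps`, `--observation_horizon`) imply a receding-horizon control loop: buffer recent observations, query a chunk of actions, execute part of it, and re-query. A self-contained sketch, with all names hypothetical rather than the real `robot_client.py` API:

```python
from collections import deque

def run_episode(policy, get_obs, apply_action, n_steps,
                exe_steps=16, observation_horizon=2):
    """Keep the last `observation_horizon` observations, query an action
    chunk from `policy` (e.g. an HTTP call to the server), execute
    `exe_steps` of it, then re-query. Illustrative sketch only."""
    obs_queue = deque(maxlen=observation_horizon)
    executed = 0
    while executed < n_steps:
        obs_queue.append(get_obs())
        chunk = policy(list(obs_queue))  # server returns an action chunk
        for action in chunk[:exe_steps]:
            apply_action(action)
            executed += 1
            if executed >= n_steps:
                break
    return executed
```

With `exe_steps` equal to `action_horizon` (both 16 above), the client executes every predicted action before querying again; a smaller `exe_steps` would re-plan more often at the cost of more server round-trips.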
## 📝 Codebase Architecture

Here's a high-level overview of the project's code structure and core components:

```
unitree-world-model/
├── assets            # Media assets such as GIFs, images, and demo videos
├── configs           # Configuration files for training and inference
│   ├── inference
│   └── train
├── examples          # Example inputs and prompts for running inference
├── external          # External packages
├── prepare_data      # Scripts for dataset preprocessing and format conversion
├── scripts           # Main scripts for training, evaluation, and deployment
├── src
│   ├── unifolm_worldmodel   # Core Python package for the Unitree world model
│   │   ├── data             # Dataset loading, transformations, and dataloaders
│   │   ├── models           # Model architectures and backbone definitions
│   │   ├── modules          # Custom model modules and components
│   │   └── utils            # Utility functions and common helpers
└── unitree_deploy    # Deployment code
```

## 🙏 Acknowledgement

Much of the code is inherited from [DynamiCrafter](https://github.com/Doubiiu/DynamiCrafter), [Diffusion Policy](https://github.com/real-stanford/diffusion_policy), [ACT](https://github.com/MarkFzp/act-plus-plus) and [HPT](https://github.com/liruiw/HPT).

## 📝 Citation

```
@misc{unifolm-wma-0,
  author = {Unitree},
  title  = {UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family},
  year   = {2025},
}
```
@@ -222,7 +222,7 @@ data:
   test:
     target: unifolm_wma.data.wma_data.WMAData
     params:
-      data_dir: '/home/dyz/unifolm-world-model-action/examples/world_model_interaction_prompts'
+      data_dir: '/home/qhy/unifolm-world-model-action/examples/world_model_interaction_prompts'
       video_length: ${model.params.wma_config.params.temporal_length}
       frame_stride: 2
       load_raw_resolution: True
@@ -9,10 +9,9 @@ import logging
 import einops
 import warnings
 import imageio
 import time
 import json
+from contextlib import contextmanager, nullcontext
+from dataclasses import dataclass, field, asdict
+import atexit
+from concurrent.futures import ThreadPoolExecutor
-from contextlib import nullcontext
 from typing import Optional, Dict, List, Any, Mapping

 from pytorch_lightning import seed_everything
@@ -21,376 +20,66 @@ from tqdm import tqdm
 from einops import rearrange, repeat
 from collections import OrderedDict
 from torch import nn
-from eval_utils import populate_queues, log_to_tensorboard
+from eval_utils import populate_queues
 from collections import deque
 from torch import Tensor
 from torch.utils.tensorboard import SummaryWriter
 from PIL import Image

 from unifolm_wma.models.samplers.ddim import DDIMSampler
 from unifolm_wma.utils.utils import instantiate_from_config
# ========== Profiling Infrastructure ==========
@dataclass
class TimingRecord:
    """Record for a single timing measurement."""
    name: str
    start_time: float = 0.0
    end_time: float = 0.0
    cuda_time_ms: float = 0.0
    count: int = 0
    children: List['TimingRecord'] = field(default_factory=list)

    @property
    def cpu_time_ms(self) -> float:
        return (self.end_time - self.start_time) * 1000

    def to_dict(self) -> dict:
        return {
            'name': self.name,
            'cpu_time_ms': self.cpu_time_ms,
            'cuda_time_ms': self.cuda_time_ms,
            'count': self.count,
            'children': [c.to_dict() for c in self.children],
        }
# ========== Async I/O ==========
_io_executor: Optional[ThreadPoolExecutor] = None
_io_futures: List[Any] = []


def _get_io_executor() -> ThreadPoolExecutor:
    """Lazily create the shared background I/O thread pool."""
    global _io_executor
    if _io_executor is None:
        _io_executor = ThreadPoolExecutor(max_workers=2)
    return _io_executor


def _flush_io() -> None:
    """Wait for all pending async I/O to finish."""
    global _io_futures
    for fut in _io_futures:
        try:
            fut.result()
        except Exception as e:
            print(f">>> [async I/O] error: {e}")
    _io_futures.clear()


class ProfilerManager:
    """Manages macro- and micro-level profiling."""

    def __init__(
        self,
        enabled: bool = False,
        output_dir: str = "./profile_output",
        profile_detail: str = "light",
    ):
        self.enabled = enabled
        self.output_dir = output_dir
        self.profile_detail = profile_detail
        self.macro_timings: Dict[str, List[float]] = {}
        self.cuda_events: Dict[str, List[tuple]] = {}
        self.memory_snapshots: List[Dict] = []
        self.pytorch_profiler = None
        self.current_iteration = 0
        self.operator_stats: Dict[str, Dict] = {}
        self.profiler_config = self._build_profiler_config(profile_detail)

        if enabled:
            os.makedirs(output_dir, exist_ok=True)

    def _build_profiler_config(self, profile_detail: str) -> Dict[str, Any]:
        """Return profiler settings based on the requested detail level."""
        if profile_detail not in ("light", "full"):
            raise ValueError(f"Unsupported profile_detail: {profile_detail}")
        if profile_detail == "full":
            return {
                "record_shapes": True,
                "profile_memory": True,
                "with_stack": True,
                "with_flops": True,
                "with_modules": True,
                "group_by_input_shape": True,
            }
        return {
            "record_shapes": False,
            "profile_memory": False,
            "with_stack": False,
            "with_flops": False,
            "with_modules": False,
            "group_by_input_shape": False,
        }
    @contextmanager
    def profile_section(self, name: str, sync_cuda: bool = True):
        """Context manager for profiling a code section."""
        if not self.enabled:
            yield
            return

        if sync_cuda and torch.cuda.is_available():
            torch.cuda.synchronize()

        start_event = None
        end_event = None
        if torch.cuda.is_available():
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()

        start_time = time.perf_counter()

        try:
            yield
        finally:
            if sync_cuda and torch.cuda.is_available():
                torch.cuda.synchronize()

            end_time = time.perf_counter()
            cpu_time_ms = (end_time - start_time) * 1000

            cuda_time_ms = 0.0
            if start_event is not None and end_event is not None:
                end_event.record()
                torch.cuda.synchronize()
                cuda_time_ms = start_event.elapsed_time(end_event)

            if name not in self.macro_timings:
                self.macro_timings[name] = []
            self.macro_timings[name].append(cpu_time_ms)

            if name not in self.cuda_events:
                self.cuda_events[name] = []
            self.cuda_events[name].append((cpu_time_ms, cuda_time_ms))
    def record_memory(self, tag: str = ""):
        """Record current GPU memory state."""
        if not self.enabled or not torch.cuda.is_available():
            return

        snapshot = {
            'tag': tag,
            'iteration': self.current_iteration,
            'allocated_mb': torch.cuda.memory_allocated() / 1024**2,
            'reserved_mb': torch.cuda.memory_reserved() / 1024**2,
            'max_allocated_mb': torch.cuda.max_memory_allocated() / 1024**2,
        }
        self.memory_snapshots.append(snapshot)

    def start_pytorch_profiler(self, wait: int = 1, warmup: int = 1, active: int = 3):
        """Start the PyTorch profiler for operator-level analysis."""
        if not self.enabled:
            return nullcontext()

        self.pytorch_profiler = torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            schedule=torch.profiler.schedule(
                wait=wait, warmup=warmup, active=active, repeat=1
            ),
            on_trace_ready=self._trace_handler,
            record_shapes=self.profiler_config["record_shapes"],
            profile_memory=self.profiler_config["profile_memory"],
            with_stack=self.profiler_config["with_stack"],
            with_flops=self.profiler_config["with_flops"],
            with_modules=self.profiler_config["with_modules"],
        )
        return self.pytorch_profiler

    def _trace_handler(self, prof):
        """Handle profiler trace output."""
        trace_path = os.path.join(
            self.output_dir,
            f"trace_iter_{self.current_iteration}.json"
        )
        prof.export_chrome_trace(trace_path)

        # Extract operator statistics
        key_averages = prof.key_averages(
            group_by_input_shape=self.profiler_config["group_by_input_shape"]
        )
        for evt in key_averages:
            op_name = evt.key
            if op_name not in self.operator_stats:
                self.operator_stats[op_name] = {
                    'count': 0,
                    'cpu_time_total_us': 0,
                    'cuda_time_total_us': 0,
                    'self_cpu_time_total_us': 0,
                    'self_cuda_time_total_us': 0,
                    'cpu_memory_usage': 0,
                    'cuda_memory_usage': 0,
                    'flops': 0,
                }
            stats = self.operator_stats[op_name]
            stats['count'] += evt.count
            stats['cpu_time_total_us'] += evt.cpu_time_total
            stats['cuda_time_total_us'] += evt.cuda_time_total
            stats['self_cpu_time_total_us'] += evt.self_cpu_time_total
            stats['self_cuda_time_total_us'] += evt.self_cuda_time_total
            if hasattr(evt, 'cpu_memory_usage'):
                stats['cpu_memory_usage'] += evt.cpu_memory_usage
            if hasattr(evt, 'cuda_memory_usage'):
                stats['cuda_memory_usage'] += evt.cuda_memory_usage
            if hasattr(evt, 'flops') and evt.flops:
                stats['flops'] += evt.flops

    def step_profiler(self):
        """Step the PyTorch profiler."""
        if self.pytorch_profiler is not None:
            self.pytorch_profiler.step()
    def generate_report(self) -> str:
        """Generate a comprehensive profiling report."""
        if not self.enabled:
            return "Profiling disabled."

        report_lines = []
        report_lines.append("=" * 80)
        report_lines.append("PERFORMANCE PROFILING REPORT")
        report_lines.append("=" * 80)
        report_lines.append("")

        # Macro-level timing summary
        report_lines.append("-" * 40)
        report_lines.append("MACRO-LEVEL TIMING SUMMARY")
        report_lines.append("-" * 40)
        report_lines.append(f"{'Section':<40} {'Count':>8} {'Total(ms)':>12} {'Avg(ms)':>12} {'CUDA Avg(ms)':>14}")
        report_lines.append("-" * 86)

        total_time = 0
        timing_data = []
        for name, times in sorted(self.macro_timings.items()):
            cuda_times = [ct for _, ct in self.cuda_events.get(name, [])]
            avg_time = np.mean(times)
            avg_cuda = np.mean(cuda_times) if cuda_times else 0
            total = sum(times)
            total_time += total
            timing_data.append({
                'name': name,
                'count': len(times),
                'total_ms': total,
                'avg_ms': avg_time,
                'cuda_avg_ms': avg_cuda,
                'times': times,
                'cuda_times': cuda_times,
            })
            report_lines.append(f"{name:<40} {len(times):>8} {total:>12.2f} {avg_time:>12.2f} {avg_cuda:>14.2f}")

        report_lines.append("-" * 86)
        report_lines.append(f"{'TOTAL':<40} {'':<8} {total_time:>12.2f}")
        report_lines.append("")

        # Memory summary
        if self.memory_snapshots:
            report_lines.append("-" * 40)
            report_lines.append("GPU MEMORY SUMMARY")
            report_lines.append("-" * 40)
            max_alloc = max(s['max_allocated_mb'] for s in self.memory_snapshots)
            avg_alloc = np.mean([s['allocated_mb'] for s in self.memory_snapshots])
            report_lines.append(f"Peak allocated:    {max_alloc:>10.2f} MB")
            report_lines.append(f"Average allocated: {avg_alloc:>10.2f} MB")
            report_lines.append("")

        # Top operators by CUDA time
        if self.operator_stats:
            report_lines.append("-" * 40)
            report_lines.append("TOP 30 OPERATORS BY CUDA TIME")
            report_lines.append("-" * 40)
            sorted_ops = sorted(
                self.operator_stats.items(),
                key=lambda x: x[1]['cuda_time_total_us'],
                reverse=True
            )[:30]

            report_lines.append(f"{'Operator':<50} {'Count':>8} {'CUDA(ms)':>12} {'CPU(ms)':>12} {'Self CUDA(ms)':>14}")
            report_lines.append("-" * 96)

            for op_name, stats in sorted_ops:
                # Truncate long operator names
                display_name = op_name[:47] + "..." if len(op_name) > 50 else op_name
                report_lines.append(
                    f"{display_name:<50} {stats['count']:>8} "
                    f"{stats['cuda_time_total_us']/1000:>12.2f} "
                    f"{stats['cpu_time_total_us']/1000:>12.2f} "
                    f"{stats['self_cuda_time_total_us']/1000:>14.2f}"
                )
            report_lines.append("")

            # Compute category breakdown
            report_lines.append("-" * 40)
            report_lines.append("OPERATOR CATEGORY BREAKDOWN")
            report_lines.append("-" * 40)

            categories = {
                'Attention': ['attention', 'softmax', 'bmm', 'baddbmm'],
                'Convolution': ['conv', 'cudnn'],
                'Normalization': ['norm', 'layer_norm', 'batch_norm', 'group_norm'],
                'Activation': ['relu', 'gelu', 'silu', 'sigmoid', 'tanh'],
                'Linear/GEMM': ['linear', 'addmm', 'mm', 'gemm'],
                'Memory': ['copy', 'contiguous', 'view', 'reshape', 'permute', 'transpose'],
                'Elementwise': ['add', 'mul', 'div', 'sub', 'pow', 'exp', 'sqrt'],
            }

            category_times = {cat: 0.0 for cat in categories}
            category_times['Other'] = 0.0

            for op_name, stats in self.operator_stats.items():
                op_lower = op_name.lower()
                categorized = False
                for cat, keywords in categories.items():
                    if any(kw in op_lower for kw in keywords):
                        category_times[cat] += stats['cuda_time_total_us']
                        categorized = True
                        break
                if not categorized:
                    category_times['Other'] += stats['cuda_time_total_us']

            total_op_time = sum(category_times.values())
            report_lines.append(f"{'Category':<30} {'CUDA Time(ms)':>15} {'Percentage':>12}")
            report_lines.append("-" * 57)
            for cat, time_us in sorted(category_times.items(), key=lambda x: -x[1]):
                pct = (time_us / total_op_time * 100) if total_op_time > 0 else 0
                report_lines.append(f"{cat:<30} {time_us/1000:>15.2f} {pct:>11.1f}%")
            report_lines.append("")

        report = "\n".join(report_lines)
        return report
    def save_results(self):
        """Save all profiling results to files."""
        if not self.enabled:
            return

        # Save report
        report = self.generate_report()
        report_path = os.path.join(self.output_dir, "profiling_report.txt")
        with open(report_path, 'w') as f:
            f.write(report)
        print(f">>> Profiling report saved to: {report_path}")

        # Save detailed JSON data
        data = {
            'macro_timings': {
                name: {
                    'times': times,
                    'cuda_times': [ct for _, ct in self.cuda_events.get(name, [])]
                }
                for name, times in self.macro_timings.items()
            },
            'memory_snapshots': self.memory_snapshots,
            'operator_stats': self.operator_stats,
        }
        json_path = os.path.join(self.output_dir, "profiling_data.json")
        with open(json_path, 'w') as f:
            json.dump(data, f, indent=2)
        print(f">>> Detailed profiling data saved to: {json_path}")

        # Print summary to console
        print("\n" + report)


# Global profiler instance
_profiler: Optional[ProfilerManager] = None

atexit.register(_flush_io)
def get_profiler() -> ProfilerManager:
    """Get the global profiler instance."""
    global _profiler
    if _profiler is None:
        _profiler = ProfilerManager(enabled=False)
    return _profiler


def init_profiler(enabled: bool, output_dir: str, profile_detail: str) -> ProfilerManager:
    """Initialize the global profiler."""
    global _profiler
    _profiler = ProfilerManager(
        enabled=enabled,
        output_dir=output_dir,
        profile_detail=profile_detail,
    )
    return _profiler
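The profiling helpers above all follow one pattern: a `contextmanager` that times the enclosed block and appends the elapsed milliseconds under a section name. A self-contained miniature of that pattern (names are mine, without the CUDA-event half):

```python
import time
from collections import defaultdict
from contextlib import contextmanager

timings = defaultdict(list)  # section name -> list of elapsed times in ms

@contextmanager
def profile_section(name):
    """Record wall-clock time of the enclosed block, like
    ProfilerManager.profile_section above (CPU side only)."""
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[name].append((time.perf_counter() - start) * 1000)

with profile_section("demo/sleep"):
    time.sleep(0.01)
```

Putting the bookkeeping in `finally` ensures the section is recorded even when the profiled code raises, which is why the full implementation takes the same shape.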
def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None:
    """Synchronous save on a CPU tensor (runs in a background thread)."""
    video = torch.clamp(video_cpu.float(), -1., 1.)
    n = video.shape[0]
    video = video.permute(2, 0, 1, 3, 4)
    frame_grids = [
        torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
        for framesheet in video
    ]
    grid = torch.stack(frame_grids, dim=0)
    grid = (grid + 1.0) / 2.0
    grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(filename,
                               grid,
                               fps=fps,
                               video_codec='h264',
                               options={'crf': '10'})


def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
    """Submit video saving to the background thread pool."""
    video_cpu = video.detach().cpu()
    fut = _get_io_executor().submit(_save_results_sync, video_cpu, filename, fps)
    _io_futures.append(fut)


# ========== Original Functions ==========
@@ -458,7 +147,8 @@ def _load_state_dict(model: nn.Module,

 def load_model_checkpoint(model: nn.Module,
                           ckpt: str,
-                          assign: bool | None = None) -> nn.Module:
+                          assign: bool | None = None,
+                          device: str | torch.device = "cpu") -> nn.Module:
     """Load model weights from checkpoint file.

     Args:
@@ -467,11 +157,12 @@ def load_model_checkpoint(model: nn.Module,
         assign (bool | None): Whether to preserve checkpoint tensor dtypes
             via load_state_dict(assign=True). If None, auto-enable when a
             casted checkpoint metadata is detected.
+        device (str | torch.device): Target device for loaded tensors.

     Returns:
         nn.Module: Model with loaded weights.
     """
-    ckpt_data = torch.load(ckpt, map_location="cpu")
+    ckpt_data = torch.load(ckpt, map_location=device, mmap=True)
     use_assign = False
     if assign is not None:
         use_assign = assign
@@ -511,9 +202,7 @@ def load_model_checkpoint(model: nn.Module,

 def maybe_cast_module(module: nn.Module | None,
                       dtype: torch.dtype,
-                      label: str,
-                      profiler: Optional[ProfilerManager] = None,
-                      profile_name: Optional[str] = None) -> None:
+                      label: str) -> None:
     if module is None:
         return
     try:
@@ -524,10 +213,6 @@ def maybe_cast_module(module: nn.Module | None,
     if param.dtype == dtype:
         print(f">>> {label} already {dtype}; skip cast")
         return
-    ctx = nullcontext()
-    if profiler is not None and profile_name:
-        ctx = profiler.profile_section(profile_name)
-    with ctx:
     module.to(dtype=dtype)
     print(f">>> {label} cast to {dtype}")
@@ -746,8 +431,6 @@ def get_latent_z(model, videos: Tensor) -> Tensor:
     Returns:
         Tensor: Latent video tensor of shape [B, C, T, H, W].
     """
-    profiler = get_profiler()
-    with profiler.profile_section("get_latent_z/encode"):
     b, c, t, h, w = videos.shape
     x = rearrange(videos, 'b c t h w -> (b t) c h w')
     vae_ctx = nullcontext()
@@ -862,8 +545,6 @@ def image_guided_synthesis_sim_mode(
         actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
         states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
     """
-    profiler = get_profiler()
-
     b, _, t, _, _ = noise_shape
     ddim_sampler = getattr(model, "_ddim_sampler", None)
     if ddim_sampler is None:
@@ -873,7 +554,6 @@ def image_guided_synthesis_sim_mode(

     fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)

-    with profiler.profile_section("synthesis/conditioning_prep"):
     img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
     cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
     if getattr(model, "encoder_bf16", False) and model.device.type == "cuda":
@@ -959,7 +639,6 @@ def image_guided_synthesis_sim_mode(
     cond_z0 = None

     if ddim_sampler is not None:
-        with profiler.profile_section("synthesis/ddim_sampling"):
         autocast_ctx = nullcontext()
         if diffusion_autocast_dtype is not None and model.device.type == "cuda":
             autocast_ctx = torch.autocast("cuda", dtype=diffusion_autocast_dtype)
@@ -982,7 +661,6 @@ def image_guided_synthesis_sim_mode(
                 **kwargs)

     # Reconstruct from latent to pixel space
-    with profiler.profile_section("synthesis/decode_first_stage"):
     if getattr(model, "vae_bf16", False):
         if samples.dtype != torch.bfloat16:
             samples = samples.to(dtype=torch.bfloat16)
@@ -1012,21 +690,31 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
     Returns:
         None
     """
-    profiler = get_profiler()
-
-    # Create inference and tensorboard dirs
+    # Create inference dir
     os.makedirs(args.savedir + '/inference', exist_ok=True)
-    log_dir = args.savedir + f"/tensorboard"
-    os.makedirs(log_dir, exist_ok=True)
-    writer = SummaryWriter(log_dir=log_dir)

     # Load prompt
     csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
     df = pd.read_csv(csv_path)

-    # Load config
-    with profiler.profile_section("model_loading/config"):
+    # Load config (always needed for data setup)
     config = OmegaConf.load(args.config)

+    prepared_path = args.ckpt_path + ".prepared.pt"
+    if os.path.exists(prepared_path):
+        # ---- Fast path: load the fully-prepared model ----
+        print(f">>> Loading prepared model from {prepared_path} ...")
+        model = torch.load(prepared_path,
+                           map_location=f"cuda:{gpu_no}",
+                           weights_only=False,
+                           mmap=True)
+        model.eval()
+        diffusion_autocast_dtype = (torch.bfloat16
+                                    if args.diffusion_dtype == "bf16"
+                                    else None)
+        print(f">>> Prepared model loaded.")
+    else:
+        # ---- Normal path: construct + checkpoint + casting ----
+        config['model']['params']['wma_config']['params'][
+            'use_checkpoint'] = False
+        model = instantiate_from_config(config.model)
@@ -1034,30 +722,18 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:

         assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"

-        with profiler.profile_section("model_loading/checkpoint"):
-            model = load_model_checkpoint(model, args.ckpt_path)
+        model = load_model_checkpoint(model, args.ckpt_path,
+                                      device=f"cuda:{gpu_no}")
         model.eval()
+        model = model.cuda(gpu_no)  # move residual buffers not in state_dict
         print(f'>>> Load pre-trained model ...')

-    # Build unnomalizer
-    logging.info("***** Configing Data *****")
-    with profiler.profile_section("data_loading"):
-        data = instantiate_from_config(config.data)
-        data.setup()
-        print(">>> Dataset is successfully loaded ...")
-
-    with profiler.profile_section("model_to_cuda"):
-        model = model.cuda(gpu_no)
-    device = get_device_from_parameters(model)
-
     diffusion_autocast_dtype = None
     if args.diffusion_dtype == "bf16":
         maybe_cast_module(
             model.model,
             torch.bfloat16,
             "diffusion backbone",
-            profiler=profiler,
-            profile_name="model_loading/diffusion_bf16",
         )
         diffusion_autocast_dtype = torch.bfloat16
         print(">>> diffusion backbone set to bfloat16")
@@ -1068,12 +744,29 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
             model.first_stage_model,
             vae_weight_dtype,
             "VAE",
-            profiler=profiler,
-            profile_name="model_loading/vae_cast",
         )
         model.vae_bf16 = args.vae_dtype == "bf16"
         print(f">>> VAE dtype set to {args.vae_dtype}")

+    # --- VAE performance optimizations ---
+    if hasattr(model, "first_stage_model") and model.first_stage_model is not None:
+        vae = model.first_stage_model
+
+        # torch.compile: fuses GroupNorm+SiLU, conv chains, etc.
+        if args.vae_compile:
+            vae.decoder = torch.compile(vae.decoder, mode="reduce-overhead")
+            vae.encoder = torch.compile(vae.encoder, mode="reduce-overhead")
+            print(">>> VAE encoder/decoder compiled with torch.compile (reduce-overhead)")
+
+        # Batch decode size
+        vae_decode_bs = args.vae_decode_bs if args.vae_decode_bs > 0 else 9999
+        model.vae_decode_bs = vae_decode_bs
+        model.vae_encode_bs = vae_decode_bs
+        if args.vae_decode_bs > 0:
+            print(f">>> VAE encode/decode batch size set to {args.vae_decode_bs}")
+        else:
+            print(">>> VAE encode/decode batch size: all frames at once")

     encoder_mode = args.encoder_mode
     encoder_bf16 = encoder_mode in ("autocast", "bf16_full")
     encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32
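The three `encoder_mode` settings map onto a (run-under-bf16-autocast?, weight dtype) pair; a sketch of that mapping with a hypothetical helper name (`resolve_encoder_mode` is not in the diff):

```python
import torch


def resolve_encoder_mode(mode: str) -> tuple[bool, torch.dtype]:
    # "autocast":  keep fp32 weights, run the forward under bf16 autocast.
    # "bf16_full": additionally cast the weights themselves to bf16.
    # anything else: plain fp32 throughout.
    use_bf16 = mode in ("autocast", "bf16_full")
    weight_dtype = torch.bfloat16 if mode == "bf16_full" else torch.float32
    return use_bf16, weight_dtype
```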
@@ -1082,16 +775,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
         model.cond_stage_model,
         encoder_weight_dtype,
         "cond_stage_model",
-        profiler=profiler,
-        profile_name="model_loading/encoder_cond_cast",
     )
     if hasattr(model, "embedder") and model.embedder is not None:
         maybe_cast_module(
             model.embedder,
             encoder_weight_dtype,
             "embedder",
-            profiler=profiler,
-            profile_name="model_loading/encoder_embedder_cast",
         )
     model.encoder_bf16 = encoder_bf16
     model.encoder_mode = encoder_mode
@@ -1107,24 +796,18 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
             model.image_proj_model,
             projector_weight_dtype,
             "image_proj_model",
-            profiler=profiler,
-            profile_name="model_loading/projector_image_cast",
         )
     if hasattr(model, "state_projector") and model.state_projector is not None:
         maybe_cast_module(
             model.state_projector,
             projector_weight_dtype,
             "state_projector",
-            profiler=profiler,
-            profile_name="model_loading/projector_state_cast",
         )
     if hasattr(model, "action_projector") and model.action_projector is not None:
         maybe_cast_module(
             model.action_projector,
             projector_weight_dtype,
             "action_projector",
-            profiler=profiler,
-            profile_name="model_loading/projector_action_cast",
         )
     if hasattr(model, "projector_bf16"):
         model.projector_bf16 = projector_bf16
@@ -1148,7 +831,18 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
         print(">>> export_only set; skipping inference.")
         return

-    profiler.record_memory("after_model_load")
+    # Save prepared model for fast loading next time
+    if prepared_path:
+        print(f">>> Saving prepared model to {prepared_path} ...")
+        torch.save(model, prepared_path)
+        print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
+
+    # Build normalizer (always needed, independent of model loading path)
+    logging.info("***** Configing Data *****")
+    data = instantiate_from_config(config.data)
+    data.setup()
+    print(">>> Dataset is successfully loaded ...")
+    device = get_device_from_parameters(model)

     # Run over data
     assert (args.height % 16 == 0) and (
@@ -1163,10 +857,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
     print(f'>>> Generate {n_frames} frames under each generation ...')
     noise_shape = [args.bs, channels, n_frames, h, w]

-    # Determine profiler iterations
-    profile_active_iters = getattr(args, 'profile_iterations', 3)
-    use_pytorch_profiler = profiler.enabled and profile_active_iters > 0
-
     # Start inference
     for idx in range(0, len(df)):
         sample = df.iloc[idx]
@@ -1182,7 +872,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:

         # Load transitions to get the initial state later
         transition_path = get_transition_path(args.prompt_dir, sample)
-        with profiler.profile_section("load_transitions"):
         with h5py.File(transition_path, 'r') as h5f:
             transition_dict = {}
             for key in h5f.keys():
@@ -1210,7 +899,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
         }

         # Obtain initial frame and state
-        with profiler.profile_section("prepare_init_input"):
         start_idx = 0
         model_input_fs = ori_fps // fs
         batch, ori_state_dim, ori_action_dim = prepare_init_input(
@@ -1233,24 +921,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
         # Update observation queues
         cond_obs_queues = populate_queues(cond_obs_queues, observation)

-        # Setup PyTorch profiler context if enabled
-        pytorch_prof_ctx = nullcontext()
-        if use_pytorch_profiler:
-            pytorch_prof_ctx = profiler.start_pytorch_profiler(
-                wait=1, warmup=1, active=profile_active_iters
-            )
-
         # Multi-round interaction with the world-model
-        with pytorch_prof_ctx:
         for itr in tqdm(range(args.n_iter)):
             log_every = max(1, args.step_log_every)
             log_step = (itr % log_every == 0)
-            profiler.current_iteration = itr
-            profiler.record_memory(f"iter_{itr}_start")

-            with profiler.profile_section("iteration_total"):
             # Get observation
-            with profiler.profile_section("prepare_observation"):
             observation = {
                 'observation.images.top':
                     torch.stack(list(
@@ -1267,7 +943,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
             # Use world-model in policy to generate action
             if log_step:
                 print(f'>>> Step {itr}: generating actions ...')
-            with profiler.profile_section("action_generation"):
             pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
                 model,
                 sample['instruction'],
@@ -1285,7 +960,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
                 diffusion_autocast_dtype=diffusion_autocast_dtype)

             # Update future actions in the observation queues
-            with profiler.profile_section("update_action_queues"):
             for act_idx in range(len(pred_actions[0])):
                 obs_update = {'action': pred_actions[0][act_idx:act_idx + 1]}
                 obs_update['action'][:, ori_action_dim:] = 0.0
@@ -1293,7 +967,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
                                                   obs_update)

             # Collect data for interacting the world-model using the predicted actions
-            with profiler.profile_section("prepare_wm_observation"):
             observation = {
                 'observation.images.top':
                     torch.stack(list(
@@ -1310,7 +983,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
             # Interaction with the world-model
             if log_step:
                 print(f'>>> Step {itr}: interacting with world model ...')
-            with profiler.profile_section("world_model_interaction"):
             pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
                 model,
                 "",
@@ -1327,7 +999,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
                 guidance_rescale=args.guidance_rescale,
                 diffusion_autocast_dtype=diffusion_autocast_dtype)

-            with profiler.profile_section("update_state_queues"):
             for step_idx in range(args.exe_steps):
                 obs_update = {
                     'observation.images.top':
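`populate_queues` is used throughout this loop as a rolling-history update over per-key observation queues. A minimal deque-based stand-in illustrating the behavior assumed here (this is not the repo's implementation):

```python
from collections import deque


def populate_queues(queues: dict, obs: dict) -> dict:
    # Each key keeps a fixed-length history; appending to a full deque
    # silently drops the oldest entry, giving a sliding window.
    for key, value in obs.items():
        if key in queues:
            queues[key].append(value)
    return queues
```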
@@ -1342,28 +1013,14 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
                 cond_obs_queues = populate_queues(cond_obs_queues,
                                                   obs_update)

-            # Save the imagen videos for decision-making
-            with profiler.profile_section("save_results"):
-                sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
-                log_to_tensorboard(writer,
-                                   pred_videos_0,
-                                   sample_tag,
-                                   fps=args.save_fps)
-                # Save videos environment changes via world-model interaction
-                sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
-                log_to_tensorboard(writer,
-                                   pred_videos_1,
-                                   sample_tag,
-                                   fps=args.save_fps)
-
-            # Save the imagen videos for decision-making
+            # Save the imagen videos for decision-making (async)
             sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
-            save_results(pred_videos_0.cpu(),
+            save_results_async(pred_videos_0,
                          sample_video_file,
                          fps=args.save_fps)
             # Save videos environment changes via world-model interaction
             sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
-            save_results(pred_videos_1.cpu(),
+            save_results_async(pred_videos_1,
                          sample_video_file,
                          fps=args.save_fps)
@@ -1371,20 +1028,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
             # Collect the result of world-model interactions
             wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu())

-            profiler.record_memory(f"iter_{itr}_end")
-            profiler.step_profiler()
-
         full_video = torch.cat(wm_video, dim=2)
-        sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
-        log_to_tensorboard(writer,
-                           full_video,
-                           sample_tag,
-                           fps=args.save_fps)
         sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
-        save_results(full_video, sample_full_video_file, fps=args.save_fps)
+        save_results_async(full_video, sample_full_video_file, fps=args.save_fps)

-    # Save profiling results
-    profiler.save_results()
+    # Wait for all async I/O to complete
+    _flush_io()


 def get_parser():
@@ -1511,6 +1160,18 @@ def get_parser():
         default="fp32",
         help="Dtype for VAE/first_stage_model weights and forward autocast."
     )
+    parser.add_argument(
+        "--vae_compile",
+        action='store_true',
+        default=False,
+        help="Apply torch.compile to VAE decoder for kernel fusion."
+    )
+    parser.add_argument(
+        "--vae_decode_bs",
+        type=int,
+        default=0,
+        help="VAE decode batch size (0=all frames at once). Reduces kernel launch overhead."
+    )
     parser.add_argument(
         "--export_casted_ckpt",
         type=str,
@@ -1556,32 +1217,6 @@ def get_parser():
         type=int,
         default=8,
         help="fps for the saving video")
-    # Profiling arguments
-    parser.add_argument(
-        "--profile",
-        action='store_true',
-        default=False,
-        help="Enable performance profiling (macro and operator-level analysis)."
-    )
-    parser.add_argument(
-        "--profile_output_dir",
-        type=str,
-        default=None,
-        help="Directory to save profiling results. Defaults to {savedir}/profile_output."
-    )
-    parser.add_argument(
-        "--profile_iterations",
-        type=int,
-        default=3,
-        help="Number of iterations to run PyTorch profiler's active phase for operator-level analysis."
-    )
-    parser.add_argument(
-        "--profile_detail",
-        type=str,
-        choices=["light", "full"],
-        default="light",
-        help="Profiling detail level. Use 'full' for shapes/stacks/memory/flops."
-    )
     return parser
@@ -1593,15 +1228,5 @@ if __name__ == '__main__':
     seed = random.randint(0, 2**31)
     seed_everything(seed)

-    # Initialize profiler
-    profile_output_dir = args.profile_output_dir
-    if profile_output_dir is None:
-        profile_output_dir = os.path.join(args.savedir, "profile_output")
-    init_profiler(
-        enabled=args.profile,
-        output_dir=profile_output_dir,
-        profile_detail=args.profile_detail,
-    )
-
     rank, gpu_num = 0, 1
     run_inference(args, gpu_num, rank)
@@ -99,7 +99,6 @@ class AutoencoderKL(pl.LightningModule):
         print(f"Restored from {path}")

     def encode(self, x, **kwargs):
-
         h = self.encoder(x)
         moments = self.quant_conv(h)
         posterior = DiagonalGaussianDistribution(moments)
@@ -1073,11 +1073,15 @@ class LatentDiffusion(DDPM):
         if not self.perframe_ae:
             encoder_posterior = self.first_stage_model.encode(x)
             results = self.get_first_stage_encoding(encoder_posterior).detach()
-        else:  ## Consume less GPU memory but slower
+        else:  ## Batch encode with configurable batch size
+            bs = getattr(self, 'vae_encode_bs', 1)
+            if bs >= x.shape[0]:
+                encoder_posterior = self.first_stage_model.encode(x)
+                results = self.get_first_stage_encoding(encoder_posterior).detach()
+            else:
                 results = []
-            for index in range(x.shape[0]):
-                frame_batch = self.first_stage_model.encode(x[index:index +
-                                                              1, :, :, :])
+                for i in range(0, x.shape[0], bs):
+                    frame_batch = self.first_stage_model.encode(x[i:i + bs])
                     frame_result = self.get_first_stage_encoding(
                         frame_batch).detach()
                     results.append(frame_result)
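The `range(0, n, bs)` loops above rely on slicing clamping the final chunk, so no special-casing is needed when `n` is not a multiple of `bs`. The chunk boundaries can be checked in isolation (`chunk_indices` is a hypothetical helper, not in the diff):

```python
def chunk_indices(n: int, bs: int) -> list[tuple[int, int]]:
    # range(0, n, bs) visits [0, bs), [bs, 2*bs), ...; the last chunk
    # may be short, and slicing x[i:i + bs] past the end is safe in
    # both Python lists and PyTorch tensors.
    return [(i, min(i + bs, n)) for i in range(0, n, bs)]
```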
@@ -1105,15 +1109,20 @@ class LatentDiffusion(DDPM):
         else:
             reshape_back = False

-        if not self.perframe_ae:
-            z = 1. / self.scale_factor * z
+        z = 1. / self.scale_factor * z

+        if not self.perframe_ae:
             results = self.first_stage_model.decode(z, **kwargs)
         else:
+            bs = getattr(self, 'vae_decode_bs', 1)
+            if bs >= z.shape[0]:
+                # all frames in one batch
+                results = self.first_stage_model.decode(z, **kwargs)
+            else:
                 results = []
-            for index in range(z.shape[0]):
-                frame_z = 1. / self.scale_factor * z[index:index + 1, :, :, :]
-                frame_result = self.first_stage_model.decode(frame_z, **kwargs)
-                results.append(frame_result)
+                for i in range(0, z.shape[0], bs):
+                    results.append(
+                        self.first_stage_model.decode(z[i:i + bs], **kwargs))
                 results = torch.cat(results, dim=0)

         if reshape_back:
@@ -55,16 +55,13 @@ class DDIMSampler(object):
                              to_torch(self.model.alphas_cumprod_prev))

         # Calculations for diffusion q(x_t | x_{t-1}) and others
-        self.register_buffer('sqrt_alphas_cumprod',
-                             to_torch(np.sqrt(alphas_cumprod.cpu())))
-        self.register_buffer('sqrt_one_minus_alphas_cumprod',
-                             to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
-        self.register_buffer('log_one_minus_alphas_cumprod',
-                             to_torch(np.log(1. - alphas_cumprod.cpu())))
-        self.register_buffer('sqrt_recip_alphas_cumprod',
-                             to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
-        self.register_buffer('sqrt_recipm1_alphas_cumprod',
-                             to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+        # Computed directly on GPU to avoid CPU↔GPU transfers
+        ac = to_torch(alphas_cumprod)
+        self.register_buffer('sqrt_alphas_cumprod', ac.sqrt())
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', (1. - ac).sqrt())
+        self.register_buffer('log_one_minus_alphas_cumprod', (1. - ac).log())
+        self.register_buffer('sqrt_recip_alphas_cumprod', ac.rsqrt())
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', (1. / ac - 1).sqrt())

         # DDIM sampling parameters
         ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
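The rewritten buffers compute the same values as the old NumPy round trip, just on the tensor's own device; in particular `ac.rsqrt()` equals `sqrt(1/ac)`. A quick equivalence check (the schedule values here are made up):

```python
import numpy as np
import torch

ac = torch.linspace(0.9, 0.1, 10, dtype=torch.float64)  # stand-in schedule

# Old path: round-trip every buffer through NumPy on the CPU.
old = torch.from_numpy(np.sqrt(1. / ac.numpy() - 1))

# New path: stay on the tensor; each buffer is a handful of fused kernels.
new = (1. / ac - 1).sqrt()

assert torch.allclose(old, new)
# rsqrt(ac) is the same quantity as sqrt(1 / ac):
assert torch.allclose(ac.rsqrt(), (1. / ac).sqrt())
```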
@@ -86,6 +83,11 @@ class DDIMSampler(object):
         self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
         self.register_buffer('ddim_sqrt_one_minus_alphas',
                              torch.sqrt(1. - ddim_alphas))
+        # Precomputed coefficients for DDIM update formula
+        self.register_buffer('ddim_sqrt_alphas', ddim_alphas.sqrt())
+        self.register_buffer('ddim_sqrt_alphas_prev', ddim_alphas_prev.sqrt())
+        self.register_buffer('ddim_dir_coeff',
+                             (1. - ddim_alphas_prev - ddim_sigmas**2).sqrt())
         sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
             (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
             (1 - self.alphas_cumprod / self.alphas_cumprod_prev))
@@ -208,18 +210,11 @@ class DDIMSampler(object):
         dp_ddim_scheduler_state = self.model.dp_noise_scheduler_state

         b = shape[0]
-        if x_T is None:
-            img = torch.randn(shape, device=device)
-            action = torch.randn((b, 16, self.model.agent_action_dim),
-                                 device=device)
-            state = torch.randn((b, 16, self.model.agent_state_dim),
-                                device=device)
-        else:
-            img = x_T
-            action = torch.randn((b, 16, self.model.agent_action_dim),
-                                 device=device)
-            state = torch.randn((b, 16, self.model.agent_state_dim),
-                                device=device)
+        img = torch.randn(shape, device=device) if x_T is None else x_T

         if precision is not None:
             if precision == 16:
@@ -362,12 +357,13 @@ class DDIMSampler(object):
                     **kwargs)
             else:
                 raise NotImplementedError
-            model_output = e_t_uncond + unconditional_guidance_scale * (
-                e_t_cond - e_t_uncond)
-            model_output_action = e_t_uncond_action + unconditional_guidance_scale * (
-                e_t_cond_action - e_t_uncond_action)
-            model_output_state = e_t_uncond_state + unconditional_guidance_scale * (
-                e_t_cond_state - e_t_uncond_state)
+            model_output = torch.lerp(e_t_uncond, e_t_cond,
+                                      unconditional_guidance_scale)
+            model_output_action = torch.lerp(e_t_uncond_action,
+                                             e_t_cond_action,
+                                             unconditional_guidance_scale)
+            model_output_state = torch.lerp(e_t_uncond_state, e_t_cond_state,
+                                            unconditional_guidance_scale)

         if guidance_rescale > 0.0:
             model_output = rescale_noise_cfg(
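`torch.lerp(u, c, s)` computes `u + s * (c - u)`, and with `s > 1` it extrapolates, which is exactly the classifier-free guidance expression being replaced, just fused into one kernel. A numeric check (the tensors here are arbitrary stand-ins):

```python
import torch

e_t_uncond = torch.randn(4)  # unconditional prediction (stand-in values)
e_t_cond = torch.randn(4)    # conditional prediction
scale = 7.5                  # guidance scale > 1 extrapolates past e_t_cond

cfg = e_t_uncond + scale * (e_t_cond - e_t_uncond)
fused = torch.lerp(e_t_uncond, e_t_cond, scale)
assert torch.allclose(fused, cfg, atol=1e-6)
```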
@@ -396,18 +392,28 @@ class DDIMSampler(object):
         sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
         sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

+        if use_original_steps:
+            sqrt_alphas = alphas.sqrt()
+            sqrt_alphas_prev = alphas_prev.sqrt()
+            dir_coeffs = (1. - alphas_prev - sigmas**2).sqrt()
+        else:
+            sqrt_alphas = self.ddim_sqrt_alphas
+            sqrt_alphas_prev = self.ddim_sqrt_alphas_prev
+            dir_coeffs = self.ddim_dir_coeff
+
         if is_video:
             size = (1, 1, 1, 1, 1)
         else:
             size = (1, 1, 1, 1)

         a_t = alphas[index].view(size)
         a_prev = alphas_prev[index].view(size)
+        sqrt_at = sqrt_alphas[index].view(size)
+        sqrt_a_prev = sqrt_alphas_prev[index].view(size)
         sigma_t = sigmas[index].view(size)
         sqrt_one_minus_at = sqrt_one_minus_alphas[index].view(size)
+        dir_coeff = dir_coeffs[index].view(size)

         if self.model.parameterization != "v":
-            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+            pred_x0 = (x - sqrt_one_minus_at * e_t) / sqrt_at
         else:
             pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)

@@ -420,14 +426,11 @@ class DDIMSampler(object):
         if quantize_denoised:
             pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

-        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
-
         noise = sigma_t * noise_like(x.shape, device,
                                      repeat_noise) * temperature
         if noise_dropout > 0.:
             noise = torch.nn.functional.dropout(noise, p=noise_dropout)

-        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+        x_prev = sqrt_a_prev * pred_x0 + dir_coeff * e_t + noise

         return x_prev, pred_x0, model_output_action, model_output_state
@@ -475,7 +478,7 @@ class DDIMSampler(object):
             sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
             sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
         else:
-            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+            sqrt_alphas_cumprod = self.ddim_sqrt_alphas
             sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas

         if noise is None:
@@ -10,8 +10,8 @@ from unifolm_wma.utils.utils import instantiate_from_config


 def nonlinearity(x):
-    # swish
-    return x * torch.sigmoid(x)
+    # swish / SiLU — single fused CUDA kernel instead of x * sigmoid(x)
+    return torch.nn.functional.silu(x)


 def Normalize(in_channels, num_groups=32):
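`F.silu` computes the same function as the old `x * torch.sigmoid(x)` (SiLU and swish are the same activation), so this is a pure kernel-count optimization. A quick check:

```python
import torch
import torch.nn.functional as F

x = torch.randn(100)
# One fused silu kernel vs. a sigmoid kernel followed by a multiply.
assert torch.allclose(F.silu(x), x * torch.sigmoid(x), atol=1e-6)
```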
@@ -7,79 +7,97 @@ MACRO-LEVEL TIMING SUMMARY
 ----------------------------------------
 Section                          Count    Total(ms)     Avg(ms)   CUDA Avg(ms)
 --------------------------------------------------------------------------------------
-action_generation                   11    399707.47    36337.04       36336.85
-data_loading                         1        52.85       52.85          52.88
-get_latent_z/encode                 22       901.39       40.97          41.01
-iteration_total                     11    836793.23    76072.11       76071.63
-load_transitions                     1         2.24        2.24           2.28
-model_loading/checkpoint             1     11833.31    11833.31       11833.43
-model_loading/config                 1     49774.19    49774.19       49774.16
-model_to_cuda                        1      8909.30     8909.30        8909.33
-prepare_init_input                   1        10.52       10.52          10.55
-prepare_observation                 11         5.41        0.49           0.53
-prepare_wm_observation              11         2.12        0.19           0.22
-save_results                        11     38668.06     3515.28        3515.32
-synthesis/conditioning_prep         22      2916.63      132.57         132.61
-synthesis/ddim_sampling             22    782695.01    35577.05       35576.86
-synthesis/decode_first_stage        22     12444.31      565.65         565.70
-update_action_queues                11         6.85        0.62           0.65
-update_state_queues                 11        17.67        1.61           1.64
-world_model_interaction             11    398375.58    36215.96       36215.75
+action_generation                   11    173133.54    15739.41       15739.36
+data_loading                         1        54.31       54.31          54.34
+get_latent_z/encode                 22       785.25       35.69          35.72
+iteration_total                     11    386482.08    35134.73       35134.55
+load_transitions                     1         2.07        2.07           2.10
+model_loading/prepared               1      4749.22     4749.22        4749.83
+prepare_init_input                   1        29.19       29.19          29.22
+prepare_observation                 11         5.49        0.50           0.53
+prepare_wm_observation              11         1.93        0.18           0.20
+save_results                        11     38791.18     3526.47        3526.51
+synthesis/conditioning_prep         22      2528.23      114.92         114.95
+synthesis/ddim_sampling             22    336003.29    15272.88       15272.83
+synthesis/decode_first_stage        22      9095.14      413.42         413.46
+update_action_queues                11         7.28        0.66           0.69
+update_state_queues                 11        17.38        1.58           1.61
+world_model_interaction             11    174516.52    15865.14       15865.07
 --------------------------------------------------------------------------------------
-TOTAL                                    2543116.13
+TOTAL                                    1126202.08
|
||||
----------------------------------------
|
||||
GPU MEMORY SUMMARY
|
||||
----------------------------------------
|
||||
Peak allocated: 17890.50 MB
|
||||
Average allocated: 16129.98 MB
|
||||
Peak allocated: 18188.29 MB
|
||||
Average allocated: 9117.49 MB
|
||||
|
||||
----------------------------------------
|
||||
TOP 30 OPERATORS BY CUDA TIME
|
||||
----------------------------------------
|
||||
Operator Count CUDA(ms) CPU(ms) Self CUDA(ms)
|
||||
------------------------------------------------------------------------------------------------
|
||||
ProfilerStep* 6 443804.16 237696.98 237689.25
|
||||
aten::linear 171276 112286.23 13179.82 0.00
|
||||
aten::addmm 81456 79537.36 3799.84 79296.37
|
||||
ampere_sgemm_128x64_tn 26400 52052.10 0.00 52052.10
|
||||
aten::matmul 90468 34234.05 6281.32 0.00
|
||||
aten::_convolution 100242 33623.79 13105.89 0.00
|
||||
aten::mm 89820 33580.74 3202.22 33253.18
|
||||
aten::convolution 100242 33575.23 13714.47 0.00
|
||||
aten::cudnn_convolution 98430 30932.19 8640.50 29248.12
|
||||
ampere_sgemm_32x128_tn 42348 20394.52 0.00 20394.52
|
||||
aten::conv2d 42042 18115.35 5932.30 0.00
|
||||
ampere_sgemm_128x32_tn 40938 16429.81 0.00 16429.81
|
||||
xformers::efficient_attention_forward_cutlass 24000 15222.23 2532.93 15120.44
|
||||
fmha_cutlassF_f32_aligned_64x64_rf_sm80(Attenti... 24000 15121.31 0.00 15121.31
|
||||
ampere_sgemm_64x64_tn 21000 14627.12 0.00 14627.12
|
||||
aten::copy_ 231819 14504.87 127056.51 14038.39
|
||||
aten::group_norm 87144 12033.73 10659.57 0.00
|
||||
aten::native_group_norm 87144 11473.40 9449.36 11002.02
|
||||
aten::conv3d 26400 8852.13 3365.43 0.00
|
||||
void at::native::(anonymous namespace)::Rowwise... 87144 8714.68 0.00 8714.68
|
||||
void cudnn::ops::nchwToNhwcKernel<float, float,... 169824 8525.44 0.00 8525.44
|
||||
aten::clone 214314 8200.26 8568.82 0.00
|
||||
void at::native::elementwise_kernel<128, 2, at:... 220440 8109.62 0.00 8109.62
|
||||
void cutlass::Kernel<cutlass_80_simt_sgemm_128x... 15000 7919.30 0.00 7919.30
|
||||
aten::_to_copy 12219 5963.43 122411.53 0.00
|
||||
aten::to 58101 5952.65 122443.72 0.00
|
||||
aten::conv1d 30000 5878.95 4556.48 0.00
|
||||
Memcpy HtoD (Pageable -> Device) 6696 5856.39 0.00 5856.39
|
||||
aten::reshape 671772 5124.03 9636.01 0.00
|
||||
sm80_xmma_fprop_implicit_gemm_indexed_tf32f32_t... 16272 5097.70 0.00 5097.70
|
||||
**Top operators — BF16 mixed-precision run** (sorted by CUDA time; `ProfilerStep*` is the whole-step envelope):

```text
Name                                                  Calls   CUDA time (ms)   CPU time (ms)   Self CUDA (ms)
---------------------------------------------------  -------  ---------------  --------------  ---------------
ProfilerStep*                                             18        690146.23       133688.74        616385.44
aten::group_norm                                      168624         24697.84        29217.27             0.00
aten::_convolution                                     96450         21420.26        12845.86             0.00
aten::convolution                                      96450         21408.68        13480.97             0.00
aten::linear                                          297398         20780.15        26257.38             0.00
aten::cudnn_convolution                                94638         18660.24         8239.04         18329.28
aten::copy_                                           772677         18135.46        17387.09         17864.87
aten::conv3d                                           52800         12922.42         8572.58             0.00
aten::conv2d                                           52469         12747.13         7725.70             0.00
aten::native_group_norm                                84312         10285.37         8974.31         10197.66
aten::_to_copy                                        590277         10270.09        22570.90             0.00
aten::to                                              602979          9655.26        23666.06             0.00
aten::conv1d                                           56245          8174.37        10015.24             0.00
void at::native::(anonymous namespace)::Rowwise...     84312          7979.71            0.00          7979.71
aten::clone                                           177132          7502.90         7007.48             0.00
void cudnn::ops::nchwToNhwcKernel<__nv_bfloat16...    164700          7384.52            0.00          7384.52
aten::addmm                                            81456          6958.44         3903.01          6908.44
aten::layer_norm                                       65700          5698.92         7816.08             0.00
void at::native::elementwise_kernel<128, 4, at:...    149688          5372.46            0.00          5372.46
void at::native::unrolled_elementwise_kernel<at...    180120          5165.28            0.00          5165.28
ampere_bf16_s16816gemm_bf16_128x128_ldg8_relu_f...     24900          4449.05            0.00          4449.05
void at::native::unrolled_elementwise_kernel<at...    368664          4405.30            0.00          4405.30
aten::reshape                                         686778          3771.84         8309.51             0.00
aten::contiguous                                       46008          3400.88         1881.73             0.00
sm80_xmma_fprop_implicit_gemm_bf16bf16_bf16f32_...     15516          3398.03            0.00          3398.03
aten::matmul                                           90489          3366.62         4946.69             0.00
aten::mm                                               89820          3284.53         3308.76          3228.56
void at::native::elementwise_kernel<128, 2, at:...     46518          2441.55            0.00          2441.55
aten::add                                             113118          2426.66         2776.23          2385.52
void at::native::elementwise_kernel<128, 4, at:...    104550          2426.41            0.00          2426.41
```
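Comparing matching rows across the two runs gives the per-operator speedup from the BF16 switch. A minimal sketch using the Self CUDA figures quoted in the tables (numbers copied from the profiler output above, not re-measured):

```python
# Self CUDA time (ms) per operator, copied from the FP32 and BF16 tables above.
fp32_ms = {
    "aten::addmm": 79296.37,
    "aten::mm": 33253.18,
    "aten::cudnn_convolution": 29248.12,
}
bf16_ms = {
    "aten::addmm": 6908.44,
    "aten::mm": 3228.56,
    "aten::cudnn_convolution": 18329.28,
}

def speedup(op: str) -> float:
    """FP32 time divided by BF16 time for one operator."""
    return fp32_ms[op] / bf16_ms[op]

for op in fp32_ms:
    print(f"{op}: {speedup(op):.2f}x")
```

The GEMM-backed ops (`addmm`, `mm`) gain roughly an order of magnitude, while `cudnn_convolution` improves far less, which is consistent with convolution dominating the BF16 category breakdown below.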
```text
----------------------------------------
OPERATOR CATEGORY BREAKDOWN
----------------------------------------
Category        CUDA Time (ms)   Percentage
-------------------------------------------
FP32 run:
Other                481950.47        41.9%
Linear/GEMM          342333.09        29.8%
Convolution          159920.77        13.9%
Elementwise           54682.93         4.8%
Memory                36883.36         3.2%
Attention             34736.13         3.0%
Normalization         32081.19         2.8%
Activation             6449.19         0.6%

BF16 run:
Other                723472.91        71.9%
Convolution          114469.81        11.4%
Memory                53845.46         5.4%
Normalization         46852.57         4.7%
Linear/GEMM           35354.58         3.5%
Elementwise           17078.44         1.7%
Activation            12296.29         1.2%
Attention              2956.61         0.3%
```
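The category rows are presumably produced by keyword-matching the profiler's operator names. A hypothetical sketch of such a bucketing rule (the keyword lists are my guess for illustration, not the actual rules of the profiling script):

```python
# Map an operator/kernel name to a coarse category by substring matching.
# Order matters: GEMM is checked before Elementwise so "addmm" is not
# swallowed by the "add" keyword. The keyword lists are illustrative guesses.
CATEGORY_KEYWORDS = [
    ("Linear/GEMM",   ("linear", "addmm", "matmul", "mm", "gemm")),
    ("Convolution",   ("conv",)),
    ("Attention",     ("attention", "fmha")),
    ("Normalization", ("norm",)),
    ("Elementwise",   ("elementwise", "add", "mul")),
    ("Memory",        ("copy", "memcpy", "clone", "contiguous")),
    ("Activation",    ("silu", "gelu", "relu", "sigmoid")),
]

def categorize(op_name: str) -> str:
    """Return the first category whose keyword appears in the name."""
    name = op_name.lower()
    for category, keywords in CATEGORY_KEYWORDS:
        if any(k in name for k in keywords):
            return category
    return "Other"
```

Anything unmatched (e.g. `ProfilerStep*`, reshapes, indexing) falls into "Other", which explains why that bucket dominates once the GEMM time collapses in the BF16 run.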
```text
------------------------------------------------------------------------------------------
aten::addmm (Linear/GEMM) UTILIZATION ANALYSIS ON A100
------------------------------------------------------------------------------------------
Effective compute precision: BF16 Tensor Core (312 TFLOPS)
torch.backends.cuda.matmul.allow_tf32 = False

Metric                               Value
-------------------------------------------------------------------
Total aten::addmm calls              81,456
Total Self CUDA time                 6908.44 ms
Total FLOPs (profiler)               1.33 PFLOPs
Achieved throughput                  191.88 TFLOP/s
A100 peak throughput                 312.00 TFLOP/s
MFU (Model FLOPs Utilization)        61.50%

INTERPRETATION:
-------------------------------------------------------------------
Good utilization (>60%). GEMM kernels are compute-bound
and running efficiently on Tensor Cores.
```
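The MFU figure follows directly from the three inputs above. A quick re-derivation (using the rounded 1.33 PFLOPs from the table, which lands at roughly 61.7% instead of the reported 61.50% computed from the unrounded FLOP count):

```python
total_flops = 1.33e15        # total FLOPs the profiler attributes to aten::addmm (rounded)
self_cuda_s = 6908.44 / 1e3  # total Self CUDA time, converted from ms to seconds
peak_tflops = 312.0          # A100 BF16 Tensor Core peak throughput

# Achieved throughput is simply work divided by kernel time.
achieved_tflops = total_flops / self_cuda_s / 1e12
mfu = achieved_tflops / peak_tflops

print(f"{achieved_tflops:.1f} TFLOP/s, MFU = {mfu:.1%}")
```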
**Run-script diff** (switches the GPU and adds `--vae_dtype bf16`):

```diff
@@ -2,7 +2,7 @@ res_dir="unitree_g1_pack_camera/case1"
 dataset="unitree_g1_pack_camera"
 
 {
-time CUDA_VISIBLE_DEVICES=1 python3 scripts/evaluation/world_model_interaction.py \
+time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
 --seed 123 \
 --ckpt_path ckpts/unifolm_wma_dual_mixbf16.ckpt \
 --config configs/inference/world_model_interaction.yaml \
@@ -23,5 +23,6 @@ dataset="unitree_g1_pack_camera"
 --perframe_ae \
 --diffusion_dtype bf16 \
 --projector_mode bf16_full \
---encoder_mode bf16_full
+--encoder_mode bf16_full \
+--vae_dtype bf16
 } 2>&1 | tee "${res_dir}/output.log"
```