# UnifoLM World Model Action - Detailed Model Architecture Analysis

## Table of Contents

1. [Overall Architecture Overview](#1-overall-architecture-overview)
2. [Inference Pipeline Analysis](#2-inference-pipeline-analysis)
3. [Core Components in Detail](#3-core-components-in-detail)
4. [Performance Bottleneck Analysis](#4-performance-bottleneck-analysis)
5. [Kernel Fusion Optimization Recommendations](#5-kernel-fusion-optimization-recommendations)

---
## 1. Overall Architecture Overview

### 1.1 Model Hierarchy

```
DDPM (top-level model)
├── DiffusionWrapper (conditioning wrapper)
│   └── UNet3D (core diffusion model)
│       ├── Time Embedding
│       ├── Downsampling Blocks
│       ├── Middle Blocks
│       └── Upsampling Blocks
├── VAE (variational autoencoder)
│   ├── Encoder
│   └── Decoder
├── CLIP Image Encoder
├── Text Encoder
├── State Projector
└── Action Projector
```
### 1.2 Inference-Time Data Flow

```
Input observation
        ↓
[1] Condition encoding stage
    ├── Image → CLIP Encoder → Image Embedding
    ├── Image → VAE Encoder → Latent Condition
    ├── Text → Text Encoder → Text Embedding
    ├── State → State Projector → State Embedding
    └── Action → Action Projector → Action Embedding
        ↓
[2] DDIM sampling stage (n iterative steps)
    ├── Initialize noise x_T
    └── For step in [0, n):
        ├── Model forward pass (UNet3D)
        │   ├── Timestep embedding
        │   ├── Condition injection (CrossAttention)
        │   └── Noise prediction
        ├── DDIM update rule
        └── x_{t-1} = f(x_t, noise_pred)
        ↓
[3] VAE decoding stage
    └── Latent → VAE Decoder → video frames
```

---
## 2. Inference Pipeline Analysis

### 2.1 Stage 1: Action Generation (sim_mode=False)

**Purpose**: generate an action sequence from observations and an instruction.

**Inputs**:
- `observation.images.top`: historical image observations [B, C, T_obs, H, W]
- `observation.state`: historical states [B, T_obs, state_dim]
- `action`: historical actions [B, T_action, action_dim]
- `instruction`: text instruction

**Outputs**:
- `pred_videos`: predicted video [B, C, T, H, W]
- `pred_actions`: predicted action sequence [B, T, action_dim]

**Key characteristics**:
- The action condition is zeroed out (`cond_action_emb = torch.zeros_like(...)`)
- The text instruction serves as the primary guidance

### 2.2 Stage 2: World-Model Interaction (sim_mode=True)

**Purpose**: predict future observations from actions.

**Inputs**:
- `observation.images.top`: current image
- `observation.state`: current state
- `action`: the action sequence generated in Stage 1

**Outputs**:
- `pred_videos`: predicted future video frames
- `pred_states`: predicted future states

**Key characteristics**:
- No text instruction is used (`text_input=False`)
- The action condition is actually consumed

---
## 3. Core Components in Detail

### 3.1 DDIM Sampler (DDIMSampler)

**Code location**: [src/unifolm_wma/models/samplers/ddim.py](src/unifolm_wma/models/samplers/ddim.py)

**Core method**: `ddim_sampling()` (lines 168-300)

**Actual code structure**:
```python
def ddim_sampling(self, cond, shape, x_T=None, ddim_steps=50, ...):
    # Initialization
    timesteps = self.ddim_timesteps[:ddim_steps]
    x = x_T if x_T is not None else torch.randn(shape, device=device)

    # Main loop
    for i, step in enumerate(iterator):
        # Current timestep
        index = total_steps - i - 1
        ts = torch.full((b,), step, device=device, dtype=torch.long)

        # Model forward pass (the core bottleneck)
        outs = self.p_sample_ddim(x, cond, ts, index=index, ...)
        x, pred_x0 = outs

    return x
```
**Performance data** (from the profiling report, `--ddim_steps 50`):
- DDIM sampling calls: 22 (11 each for action_generation and world_model_interaction)
- Average time per 50-step sampling run: 35.58 s (782.70 s total)
- Average time per step: ~0.712 s (35.58 s / 50)
- With the current `unconditional_guidance_scale=1.0`, each step runs 1 UNet forward pass; with CFG enabled, 2 forward passes per step
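The `p_sample_ddim` call above implements the per-step update. For reference, here is a minimal sketch of the textbook deterministic DDIM (η=0) step written against the usual cumulative-ᾱ schedule; it illustrates the standard formula, not the repository's exact implementation:

```python
import torch

def ddim_update(x_t, noise_pred, alpha_bar_t, alpha_bar_prev):
    """Textbook DDIM (eta=0) step x_t -> x_{t-1}; alpha_bar_* are cumulative-product tensors."""
    # Predicted clean sample x_0 recovered from the noise estimate
    pred_x0 = (x_t - (1.0 - alpha_bar_t).sqrt() * noise_pred) / alpha_bar_t.sqrt()
    # Deterministic direction pointing to x_{t-1}
    dir_xt = (1.0 - alpha_bar_prev).sqrt() * noise_pred
    x_prev = alpha_bar_prev.sqrt() * pred_x0 + dir_xt
    return x_prev, pred_x0
```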
### 3.2 DiffusionWrapper (Condition Router)

**Code location**: [src/unifolm_wma/models/ddpms.py:2413-2524](src/unifolm_wma/models/ddpms.py)

**Role**: routes inputs and conditions to the inner diffusion model.

**Actual code** (lines 2469-2479):
```python
elif self.conditioning_key == 'hybrid':
    xc = torch.cat([x] + c_concat, dim=1)  # concatenate the latent condition along channels
    cc = torch.cat(c_crossattn, 1)         # concatenate cross-attention conditions
    cc_action = c_crossattn_action
    out = self.diffusion_model(xc, x_action, x_state, t,
                               context=cc, context_action=cc_action, **kwargs)
```

**Condition types**:
1. **c_concat**: channel-concatenated condition (VAE-encoded images)
2. **c_crossattn**: cross-attention conditions (text, image, state, and action embeddings)
3. **c_crossattn_action**: condition dedicated to the action head
### 3.3 WMAModel (Core Diffusion Model)

**Code location**: [src/unifolm_wma/modules/networks/wma_model.py:326-849](src/unifolm_wma/modules/networks/wma_model.py)

**Config file**: [configs/inference/world_model_interaction.yaml:69-104](configs/inference/world_model_interaction.yaml)

**Actual configuration**:
```yaml
in_channels: 8                   # input channels (4 latent + 4 VAE condition)
out_channels: 4                  # output channels
model_channels: 320              # base channel count
channel_mult: [1, 2, 4, 4]       # channel multipliers: [320, 640, 1280, 1280]
num_res_blocks: 2                # 2 ResBlocks per resolution
attention_resolutions: [4, 2, 1] # resolutions with attention enabled
num_head_channels: 64            # 64 channels per attention head
transformer_depth: 1             # Transformer depth
context_dim: 1024                # cross-attention context dimension
temporal_length: 16              # temporal sequence length
```

**Architecture layout** (see Appendix A.1 for details):
- 4 downsampling stages (2 ResBlocks + Attention per stage)
- 1 middle block (2 ResBlocks + Attention)
- 3 upsampling stages (2 ResBlocks + Attention per stage)
- Total: 16 ResBlocks + 32 Transformers
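As a quick sanity check, the per-stage channel widths quoted above follow directly from `model_channels` and `channel_mult`:

```python
model_channels = 320
channel_mult = [1, 2, 4, 4]
print([model_channels * m for m in channel_mult])  # [320, 640, 1280, 1280]
```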
### 3.4 VAE (Variational Autoencoder)

**Code location**: [src/unifolm_wma/models/autoencoder.py](src/unifolm_wma/models/autoencoder.py)

**Config file**: [configs/inference/world_model_interaction.yaml:159-180](configs/inference/world_model_interaction.yaml)

**Actual configuration**:
```yaml
AutoencoderKL:
  embed_dim: 4          # latent dimension
  z_channels: 4         # latent channels
  in_channels: 3        # RGB input
  out_ch: 3             # RGB output
  ch: 128               # base channel count
  ch_mult: [1, 2, 4, 4] # channel multipliers: [128, 256, 512, 512]
  num_res_blocks: 2     # 2 ResBlocks per level
  attn_resolutions: []  # no attention in the VAE
```

**Encode/decode**:
```python
# Encode: [B, 3, 320, 512] → [B, 4, 40, 64] (8×8 downsampling)
z = model.encode_first_stage(img)

# Decode: [B, 4, 40, 64] → [B, 3, 320, 512]
video = model.decode_first_stage(samples)
```

**Performance data**:
- VAE encode: 0.90 s (22 calls, 0.041 s/call on average)
- VAE decode: 12.44 s (22 calls, 0.566 s/call on average)
- Compression ratio: 8×8 = 64× spatial compression

**Detailed architecture**: see Appendix A.4.
### 3.5 Condition Encoders

**Profiling note**: this profiling run did not time the individual condition encoders; they are aggregated under `synthesis/conditioning_prep`, totaling 2.92 s (22 calls, 0.133 s/call on average).
#### 3.5.1 CLIP Image Encoder

**Code location**: [src/unifolm_wma/modules/encoders/condition.py](src/unifolm_wma/modules/encoders/condition.py) - `FrozenOpenCLIPImageEmbedderV2`

**Config file**: [configs/inference/world_model_interaction.yaml:188-204](configs/inference/world_model_interaction.yaml)

**Actual configuration**:
```yaml
FrozenOpenCLIPImageEmbedderV2:
  freeze: true
  # uses OpenCLIP ViT-H/14
  # output: [B, 1280]

Resampler (image projector):
  dim: 1024           # output dimension
  depth: 4            # Transformer depth
  heads: 12           # 12 attention heads
  num_queries: 16     # 16 query tokens
  embedding_dim: 1280 # CLIP output dimension
```

**Data flow**: image [B, 3, H, W] → CLIP → [B, 1280] → Resampler → [B, 16, 1024]
#### 3.5.2 Text Encoder

**Code location**: [src/unifolm_wma/modules/encoders/condition.py](src/unifolm_wma/modules/encoders/condition.py) - `FrozenOpenCLIPEmbedder`

**Config file**: [configs/inference/world_model_interaction.yaml:182-186](configs/inference/world_model_interaction.yaml)

**Actual configuration**:
```yaml
FrozenOpenCLIPEmbedder:
  freeze: True
  layer: "penultimate" # use the second-to-last layer
  # output: [B, seq_len, 1024]
```
#### 3.5.3 State Projector

**Code location**: [src/unifolm_wma/models/ddpms.py:2014-2026](src/unifolm_wma/models/ddpms.py) - `MLPProjector`

**MLPProjector implementation** (src/unifolm_wma/utils/projector.py:14-37, abridged excerpt; `super().__init__()` and the trivial `forward` are filled in here for runnability):
```python
class MLPProjector(nn.Module):
    def __init__(self, input_dim: int, output_dim: int, mlp_type: str = "gelu-mlp"):
        super().__init__()
        if mlp_type == "gelu-mlp":
            self.projector = nn.Sequential(
                nn.Linear(input_dim, output_dim, bias=True),
                nn.GELU(approximate='tanh'),
                nn.Linear(output_dim, output_dim, bias=True),
            )

    def forward(self, x):
        return self.projector(x)
```

**Data flow**: state [B, T_obs, 16] → MLPProjector → [B, T_obs, 1024] + agent_state_pos_emb
#### 3.5.4 Action Projector

**Code location**: [src/unifolm_wma/models/ddpms.py:2020-2024](src/unifolm_wma/models/ddpms.py) - `MLPProjector`

**Data flow**: action [B, T_action, 16] → MLPProjector → [B, T_action, 1024] + agent_action_pos_emb

**Positional embedding definitions**:
```python
# ddpms.py:2023-2026
self.agent_action_pos_emb = nn.Parameter(torch.randn(1, 16, 1024))
self.agent_state_pos_emb = nn.Parameter(torch.randn(1, n_obs_steps, 1024))
```

---
## 4. Performance Bottleneck Analysis

### 4.1 Time Breakdown (profiling report, 11 iterations / 22 sampling runs)

Note: `profile_section` scopes are nested, so the top-level totals do not sum to wall time; read each row's total/avg on its own.

| Section | Count | Total (s) | Avg (s) | Notes |
|---------|-------|-----------|---------|-------|
| iteration_total | 11 | 836.79 | 76.07 | total time per iteration |
| action_generation | 11 | 399.71 | 36.34 | action generation (50-step DDIM) |
| world_model_interaction | 11 | 398.38 | 36.22 | world-model interaction (50-step DDIM) |
| synthesis/ddim_sampling | 22 | 782.70 | 35.58 | one sampling run |
| synthesis/conditioning_prep | 22 | 2.92 | 0.13 | condition encoding, aggregated |
| synthesis/decode_first_stage | 22 | 12.44 | 0.57 | VAE decoding |
| save_results | 11 | 38.67 | 3.52 | I/O for saving outputs |
| model_loading/config | 1 | 49.77 | 49.77 | one-time cost |
| model_loading/checkpoint | 1 | 11.83 | 11.83 | one-time cost |
| model_to_cuda | 1 | 8.91 | 8.91 | one-time cost |
### 4.2 DDIM Sampling in Detail

**DDIM sampling is the primary bottleneck** (at 50 sampling steps):

- Sampling calls: 22 (11 iterations × 2 stages)
- Total sampling time: 782.70 s, 35.58 s/call on average
- Average time per step: ~0.712 s (35.58 s / 50)
- With `unconditional_guidance_scale=1.0`, each step runs 1 UNet forward pass; with CFG enabled, 2 forward passes per step
- Within `conditioning_prep + ddim_sampling + decode_first_stage`, ddim_sampling accounts for about 98% (verified below)
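The ~98% figure follows directly from the Section 4.1 table:

$$
\frac{782.70}{2.92 + 782.70 + 12.44} = \frac{782.70}{798.06} \approx 98.1\%
$$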
### 4.3 Bottleneck Summary

**Key findings**:
1. **DDIM sampling dominates** - a single iteration averages 76.07 s, of which sampling accounts for about 71.15 s (≈93%)
2. **CUDA operator time is concentrated in Linear/GEMM (29.8%) and Convolution (13.9%)**; Attention is only about 3.0%
3. **CPU-side copy/transfer costs remain significant** (`aten::copy_` and `aten::to/_to_copy` rank high in the report)
4. VAE decoding is the secondary bottleneck (0.57 s/call)
### 4.4 GPU Memory Overview

- Peak allocated: 17890.50 MB
- Average allocated: 16129.98 MB

---
## 5. Kernel Fusion Optimization Recommendations

### 5.1 Strategy Overview

Based on the profiling results, optimization should focus on:
1. **The UNet3D forward pass** (highest priority)
2. **The VAE decoder** (secondary priority)
3. **Batching and parallelization** (supporting optimizations)
### 5.2 Kernel Fusion Opportunities in WMAModel

#### 5.2.1 Timestep Embedding Fusion

**Code location**: [src/unifolm_wma/utils/diffusion.py](src/unifolm_wma/utils/diffusion.py) - `timestep_embedding()`

**Current implementation** (actual code):
```python
# 1. Sinusoidal positional encoding
def timestep_embedding(timesteps, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(0, half) / half)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    return embedding

# 2. Time embedding network (in WMAModel.__init__)
self.time_embed = nn.Sequential(
    nn.Linear(model_channels, time_embed_dim),  # 320 → 1280
    nn.SiLU(),
    nn.Linear(time_embed_dim, time_embed_dim),  # 1280 → 1280
)
```

**Fusion opportunities** (a sketch follows below):
- `Linear + SiLU + Linear` can be fused into a single kernel
- The sinusoidal encoding can be fused with the first Linear

**Expected gain**: saves 2-3 kernel launches
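A minimal sketch of the easiest route to this fusion: let `torch.compile` fold the pointwise SiLU into the surrounding GEMMs rather than hand-writing a kernel. Shapes come from the config above; the variable names are illustrative:

```python
import torch
import torch.nn as nn

model_channels, time_embed_dim = 320, 1280

# The same Linear + SiLU + Linear stack as in WMAModel.__init__
time_embed = nn.Sequential(
    nn.Linear(model_channels, time_embed_dim),
    nn.SiLU(),
    nn.Linear(time_embed_dim, time_embed_dim),
).cuda()

# Inductor can fuse the activation into the matmul epilogue
time_embed_compiled = torch.compile(time_embed)

t_emb = torch.randn(8, model_channels, device="cuda")
out = time_embed_compiled(t_emb)  # [8, 1280]
```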
#### 5.2.2 ResBlock Kernel Fusion

**Code location**: [src/unifolm_wma/modules/networks/wma_model.py:130-263](src/unifolm_wma/modules/networks/wma_model.py) - `class ResBlock`

**Current implementation** (actual code):
```python
# in_layers: GroupNorm + SiLU + Conv
self.in_layers = nn.Sequential(
    normalization(channels),  # GroupNorm
    nn.SiLU(),
    conv_nd(dims, channels, out_channels, 3, padding=1)
)

# emb_layers: SiLU + Linear
self.emb_layers = nn.Sequential(
    nn.SiLU(),
    nn.Linear(emb_channels, out_channels)
)

# out_layers: GroupNorm + SiLU + Dropout + Conv
self.out_layers = nn.Sequential(
    normalization(out_channels),  # GroupNorm
    nn.SiLU(),
    nn.Dropout(p=dropout),
    zero_module(nn.Conv2d(out_channels, out_channels, 3, padding=1))
)
```

**Fusion opportunities** (see the sketch after this list):
1. `GroupNorm + SiLU` can be fused (once each in in_layers and out_layers)
2. The `SiLU + Linear` in `emb_layers` can be fused
3. The residual add can be fused with the next layer's GroupNorm

**Actual load**: 16 ResBlocks × 50 steps × 2 stages = **1600 ResBlock invocations**

**Expected gain**: 50-60% of kernel launch overhead saved per ResBlock
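A minimal sketch of fusion opportunity 1, again relying on `torch.compile` to emit a fused kernel for the `GroupNorm + SiLU` pair instead of a custom CUDA kernel; the module and tensor shapes here are illustrative, not the repository's code:

```python
import torch
import torch.nn as nn

class NormAct(nn.Module):
    """The GroupNorm + SiLU pair as it appears in in_layers/out_layers."""
    def __init__(self, channels: int, groups: int = 32):
        super().__init__()
        self.norm = nn.GroupNorm(groups, channels)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.norm(x))

# Inductor can fuse the normalization and activation into one kernel
block = torch.compile(NormAct(320).cuda())

x = torch.randn(2, 320, 40, 64, device="cuda")
y = block(x)
```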
#### 5.2.3 Attention Optimization

**Code location**: [src/unifolm_wma/modules/attention.py](src/unifolm_wma/modules/attention.py)

**Actual setup**:
- SpatialTransformer: attention over the spatial dimensions
- TemporalTransformer: attention over the temporal dimension
- Total: 32 Transformers × 50 steps × 2 stages = **3200 attention invocations**

**Optimization**:
Use PyTorch's built-in Flash Attention:
```python
from torch.nn.functional import scaled_dot_product_attention

# Replaces the standard attention computation
out = scaled_dot_product_attention(Q, K, V, is_causal=False)
```

**Expected gain**: 2-3× speedup in the attention layers (note: per Section 4.3, attention is only ~3% of CUDA time, so the end-to-end gain is modest)
### 5.3 VAE Decoder Optimization

**Code location**: [src/unifolm_wma/models/autoencoder.py](src/unifolm_wma/models/autoencoder.py)

**Current cost**: 12.44 s (22 calls, 0.566 s/call on average)

**Optimizations** (a batching sketch follows below):
1. **Mixed precision**: decode in FP16
```python
with torch.cuda.amp.autocast():
    video = vae.decode(latent)
```

2. **Batching**: make sure the VAE decodes in batches rather than frame by frame

**Expected gain**: 20-30% speedup
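A sketch of the batching idea: fold the time axis into the batch axis so all frames go through the decoder in one call. This assumes a decoder that accepts `[N, 4, h, w]` batches; `decode_video_batched` and its shapes are illustrative:

```python
import torch
from einops import rearrange

@torch.no_grad()
def decode_video_batched(vae, latents: torch.Tensor) -> torch.Tensor:
    """Decode [B, 4, T, h, w] latents in one batched call instead of a per-frame loop."""
    b = latents.shape[0]
    frames = rearrange(latents, "b c t h w -> (b t) c h w")
    with torch.autocast("cuda"):      # combine with the FP16 suggestion above
        video = vae.decode(frames)    # [(b t), 3, H, W]
    return rearrange(video, "(b t) c h w -> b c t h w", b=b)
```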
### 5.4 Implementation Suggestions

#### 5.4.1 Use torch.compile() (simplest)

**Code location**: [scripts/evaluation/world_model_interaction.py](scripts/evaluation/world_model_interaction.py)

**Where to apply**:
Add after the model is loaded:

```python
# After the model is loaded and moved to the GPU
config = OmegaConf.load(args.config)
model = instantiate_from_config(config.model)
model = load_model_checkpoint(model, args.ckpt_path)
model.eval()
model = model.cuda()

# Add torch.compile() optimization
model.model.diffusion_model = torch.compile(
    model.model.diffusion_model,
    mode='max-autotune',  # or 'reduce-overhead'
    fullgraph=True
)
```

**Advantages**:
- No model code changes required
- Fuses operations automatically
- Supports dynamic shapes

**Expected gain**: 20-40% speedup
#### 5.4.2 Use Flash Attention

**Code location**: [src/unifolm_wma/modules/attention.py](src/unifolm_wma/modules/attention.py)

**Current implementation**:
The code already supports xformers (`xformers.ops.memory_efficient_attention`). When xformers is available, the `CrossAttention` class automatically switches to its `efficient_forward` method:

```python
# attention.py:90-91
if XFORMERS_IS_AVAILBLE and temporal_length is None:
    self.forward = self.efficient_forward
```

**Further optimization**:
If xformers is unavailable, fall back to PyTorch's built-in Flash Attention:
```python
from torch.nn.functional import scaled_dot_product_attention
out = scaled_dot_product_attention(q, k, v, is_causal=False)
```

**Expected gain**: 2-3× speedup in the attention layers
#### 5.4.3 Mixed-Precision Inference

**Code location**: [scripts/evaluation/world_model_interaction.py](scripts/evaluation/world_model_interaction.py)

**Where to apply**:
Wrap the inference call in an autocast context:

```python
# Around the image_guided_synthesis_sim_mode call
with torch.cuda.amp.autocast():
    pred_videos, pred_actions, pred_states = image_guided_synthesis_sim_mode(
        model, sample['instruction'], observation, noise_shape,
        action_cond_step=args.exe_steps,
        ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
        unconditional_guidance_scale=args.unconditional_guidance_scale,
        fs=model_input_fs, timestep_spacing=args.timestep_spacing,
        guidance_rescale=args.guidance_rescale, sim_mode=False)
```

**Notes**:
- The model switches between FP16 and FP32 automatically
- Some operations (e.g., LayerNorm) stay in FP32 automatically
- No manual weight conversion required

**Expected gain**: 30-50% speedup + ~50% memory reduction
### 5.5 Optimization Roadmap

#### Phase 1: Quick Wins

**Goal**: 20-40% speedup with no model code changes

**Steps**:
1. Enable `torch.compile()` - add after model loading
2. Enable `torch.backends.cudnn.benchmark = True` - set before inference starts
3. Use mixed-precision inference (FP16) - wrap the inference call

**Implementation**:
```python
# At the start of the inference function
torch.backends.cudnn.benchmark = True

# After model loading, add torch.compile()
model = model.cuda()
model.model.diffusion_model = torch.compile(
    model.model.diffusion_model,
    mode='max-autotune'
)

# Use mixed precision in the inference loop
with torch.cuda.amp.autocast():
    pred_videos, pred_actions, pred_states = image_guided_synthesis_sim_mode(...)
```
#### Phase 2: Intermediate Optimizations

**Goal**: 50-70% speedup

**Steps**:
1. Make sure xformers is installed and enabled - check the `XFORMERS_IS_AVAILBLE` flag
2. Optimize VAE decoder batching - check the `decode()` method in [src/unifolm_wma/models/autoencoder.py](src/unifolm_wma/models/autoencoder.py)
3. Analyze and optimize memory access patterns - use `torch.cuda.memory_stats()`

**Key touch points**:
- Confirm xformers is installed correctly: `pip install xformers`
- In the `CrossAttention` class, `efficient_forward` is used automatically when xformers is available
- Make sure VAE decoding runs batched rather than frame by frame
#### Phase 3: Deep Optimizations

**Goal**: 2-3× speedup

**Steps**:
1. Fuse key operations with custom CUDA kernels
   - Fuse GroupNorm + SiLU + Conv (in the ResBlock at [src/unifolm_wma/modules/networks/wma_model.py:130-263](src/unifolm_wma/modules/networks/wma_model.py))
2. Optimize the convolutions
   - Profile the Conv2D operations (the model actually uses Conv2D, not Conv3D)
3. Optimize the data loading and preprocessing pipeline
   - Check the data preparation part of [scripts/evaluation/world_model_interaction.py](scripts/evaluation/world_model_interaction.py)

**Required skills**:
- CUDA programming
- PyTorch C++ extensions
- Profiling tools (Nsight Systems, nvprof)
### 5.6 Expected Overall Gains

Based on the strategies above and the code analysis, the expected speedups are:

| Phase | Expected speedup | Difficulty | Main files to modify |
|-------|------------------|------------|----------------------|
| Phase 1: quick wins | 1.2-1.4× | low | [scripts/evaluation/world_model_interaction.py](scripts/evaluation/world_model_interaction.py) |
| Phase 2: intermediate | 1.5-1.7× | medium | [src/unifolm_wma/modules/attention.py](src/unifolm_wma/modules/attention.py), [src/unifolm_wma/models/autoencoder.py](src/unifolm_wma/models/autoencoder.py) |
| Phase 3: deep | 2.0-3.0× | high | [src/unifolm_wma/modules/networks/wma_model.py](src/unifolm_wma/modules/networks/wma_model.py) + custom CUDA kernels |

**Overall target**: 2-3× end-to-end speedup through systematic optimization

---
## 6. Key Code Location Index

To make kernel-fusion work easier, the key code locations are:

### 6.1 Core Model Files

| Component | File path | Key class/function |
|-----------|-----------|--------------------|
| DDPM main model | `src/unifolm_wma/models/ddpms.py` | `class DDPM` |
| Conditioning wrapper | `src/unifolm_wma/models/ddpms.py:2413` | `class DiffusionWrapper` |
| DDIM sampler | `src/unifolm_wma/models/samplers/ddim.py` | `class DDIMSampler` |
| VAE encode/decode | `src/unifolm_wma/models/autoencoder.py` | `encode_first_stage`, `decode_first_stage` |

### 6.2 Inference Scripts

| File | Description |
|------|-------------|
| `scripts/evaluation/world_model_interaction.py` | inference script |

### 6.3 Config Files

| File | Description |
|------|-------------|
| `configs/inference/world_model_interaction.yaml` | inference config |
| `unitree_g1_pack_camera/case1/run_world_model_interaction.sh` | run script |

---
## 7. Recommended Next Steps

### 7.1 Immediately Actionable Optimizations

**Smallest changes, biggest gains**:

1. **Enable torch.compile()**
   - **Code location**: [scripts/evaluation/world_model_interaction.py](scripts/evaluation/world_model_interaction.py)
   - **Change**: add after model loading
```python
# After the model is loaded and moved to the GPU
model.model.diffusion_model = torch.compile(
    model.model.diffusion_model,
    mode='max-autotune'
)
```

2. **Enable cuDNN benchmark**
   - **Code location**: [scripts/evaluation/world_model_interaction.py](scripts/evaluation/world_model_interaction.py)
   - **Change**: add at the start of the inference function
```python
torch.backends.cudnn.benchmark = True
```

3. **Mixed-precision inference**
   - **Code location**: [scripts/evaluation/world_model_interaction.py](scripts/evaluation/world_model_interaction.py)
   - **Change**: wrap the `image_guided_synthesis_sim_mode` call
```python
with torch.cuda.amp.autocast():
    pred_videos, pred_actions, pred_states = image_guided_synthesis_sim_mode(...)
```

**Expected gain**: 20-40% speedup, low risk
### 7.2 Areas Needing Deeper Investigation

For more precise optimization, the following deserve further analysis (a profiler sketch follows the list):

1. **The concrete attention implementation**
   - **Code location**: [src/unifolm_wma/modules/attention.py](src/unifolm_wma/modules/attention.py)
   - **Targets**:
     - The `CrossAttention` class (lines 48-398) - core attention implementation
     - The `BasicTransformerBlock` class (lines 400-469) - Transformer block
     - Whether xformers is enabled (the `XFORMERS_IS_AVAILBLE` flag)

2. **The detailed ResBlock structure**
   - **Code location**: [src/unifolm_wma/modules/networks/wma_model.py:130-263](src/unifolm_wma/modules/networks/wma_model.py)
   - **Targets**:
     - Confirm the GroupNorm + SiLU + Conv call order
     - Identify fusable operation sequences
     - Assess the feasibility of custom CUDA kernels

3. **Memory bottleneck analysis**
   - **Tools**: `torch.cuda.memory_stats()` and `torch.profiler`
   - **Where**: [scripts/evaluation/world_model_interaction.py](scripts/evaluation/world_model_interaction.py)
   - **Targets**:
     - Identify memory-copy hotspots
     - Optimize the lifetime of intermediate tensors
     - Reduce unnecessary allocations

4. **Compute bottleneck localization**
   - **Tools**: Nsight Systems or the PyTorch Profiler
   - **Targets**:
     - Identify kernel launch overhead
     - Analyze GPU utilization
     - Find compute-heavy operations
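For items 3 and 4, a minimal `torch.profiler` sketch around the synthesis call (same call-site convention as in Section 7.1; the `...` arguments are elided as there):

```python
import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             profile_memory=True, with_stack=True) as prof:
    pred_videos, pred_actions, pred_states = image_guided_synthesis_sim_mode(...)

# Rank operators by GPU time to confirm the GEMM/Conv split reported in Section 4.3
print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20))
```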
---
## 8. References

### 8.1 Optimization Guides

- [PyTorch 2.0 torch.compile()](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html)
- [Flash Attention](https://github.com/Dao-AILab/flash-attention)
- [CUDA Kernel Fusion](https://developer.nvidia.com/blog/cuda-pro-tip-increase-performance-with-vectorized-memory-access/)

### 8.2 Related Papers

- DDIM: Denoising Diffusion Implicit Models
- Flash Attention: Fast and Memory-Efficient Exact Attention
- Efficient Diffusion Models for Vision

---
## 9. Summary

### 9.1 Key Findings

1. **DDIM sampling remains the primary bottleneck** - one 50-step sampling run averages 35.58 s
2. **Linear/GEMM and Convolution dominate CUDA time** - attention's share is comparatively small
3. **VAE decoding is the secondary optimization target** - 0.57 s/call

### 9.2 Optimization Priorities

**High priority** (implement now):
- ✅ torch.compile()
- ✅ cuDNN benchmark
- ✅ mixed-precision inference

**Medium priority** (within a week):
- Flash Attention integration
- VAE batching optimization

**Low priority** (long term):
- Custom CUDA kernels
- Model architecture changes

### 9.3 Expected Outcome

With systematic optimization, an overall **1.5-3× speedup** is expected (depending on the sampling step count and the compile/mixed-precision strategy).

---

**Document version**: v1.2
**Created**: 2026-01-17
**Last updated**: 2026-01-18
**Changelog**: recalibrated the DDIM step count to 50 and refreshed the numbers with the latest profiling data

---
## Appendix A: Actual Model Architecture in Detail

The following implementation details are based on a reading of the code.

### A.1 WMAModel Actual Configuration

**Config file**: `configs/inference/world_model_interaction.yaml:69-104`

```yaml
WMAModel parameters:
  in_channels: 8                   # input channels (4 latent + 4 concat condition)
  out_channels: 4                  # output channels (latent space)
  model_channels: 320              # base channel count
  channel_mult: [1, 2, 4, 4]       # channel multipliers: [320, 640, 1280, 1280]
  num_res_blocks: 2                # 2 ResBlocks per resolution
  attention_resolutions: [4, 2, 1] # resolutions with attention enabled
  num_head_channels: 64            # 64 channels per attention head
  transformer_depth: 1             # Transformer depth
  context_dim: 1024                # cross-attention context dimension
  temporal_length: 16              # temporal sequence length
  dropout: 0.1
```

**Architecture layout**:
```
Input: [B, 8, 16, 40, 64] (8 channels = 4 latent + 4 VAE condition)
        ↓
Downsampling path (4 stages):
  Stage 0: [B, 320, 16, 40, 64]  - 2 ResBlocks + SpatialTransformer + TemporalTransformer
  Stage 1: [B, 640, 16, 20, 32]  - Downsample + 2 ResBlocks + Attention
  Stage 2: [B, 1280, 16, 10, 16] - Downsample + 2 ResBlocks + Attention
  Stage 3: [B, 1280, 16, 5, 8]   - Downsample + 2 ResBlocks + Attention
        ↓
Middle block: [B, 1280, 16, 5, 8]
  - ResBlock + SpatialTransformer + TemporalTransformer + ResBlock
        ↓
Upsampling path (3 stages):
  Stage 2: [B, 1280, 16, 10, 16] - Upsample + 2 ResBlocks + Attention
  Stage 1: [B, 640, 16, 20, 32]  - Upsample + 2 ResBlocks + Attention
  Stage 0: [B, 320, 16, 40, 64]  - Upsample + 2 ResBlocks + Attention
        ↓
Output: [B, 4, 16, 40, 64] (predicted noise or velocity)
```
### A.2 ResBlock Actual Implementation

**Location**: `src/unifolm_wma/modules/networks/wma_model.py:130-263`

**Actual code structure**:
```python
class ResBlock:
    def __init__(self, channels, emb_channels, dropout, ...):
        # Input layers: GroupNorm + SiLU + Conv
        self.in_layers = nn.Sequential(
            normalization(channels),  # GroupNorm
            nn.SiLU(),                # activation
            conv_nd(dims, channels, out_channels, 3, padding=1)
        )

        # Timestep embedding layers: SiLU + Linear
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_channels, out_channels)
        )

        # Output layers: GroupNorm + SiLU + Dropout + Conv
        self.out_layers = nn.Sequential(
            normalization(out_channels),  # GroupNorm
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(nn.Conv2d(out_channels, out_channels, 3, padding=1))
        )

        # Residual connection
        self.skip_connection = ...

        # Optional temporal convolution
        if use_temporal_conv:
            self.temporal_conv = TemporalConvBlock(...)

    def forward(self, x, emb):
        h = self.in_layers(x)            # GroupNorm + SiLU + Conv
        emb_out = self.emb_layers(emb)   # timestep embedding
        h = h + emb_out                  # inject timestep information (broadcast over H, W)
        h = self.out_layers(h)           # GroupNorm + SiLU + Dropout + Conv
        h = self.skip_connection(x) + h  # residual connection

        if self.use_temporal_conv:
            h = self.temporal_conv(h)    # convolution over the temporal dimension
        return h
```

**Kernel fusion opportunities**:
1. `GroupNorm + SiLU` can be fused (once each in in_layers and out_layers)
2. `emb_layers` (`SiLU + Linear`) can be fused
3. The residual add can be fused with the next layer's GroupNorm
### A.3 Attention Actual Implementation

**SpatialTransformer** (spatial attention):
- Location: `src/unifolm_wma/modules/attention.py:472-558`
- Runs self-attention and cross-attention over the spatial dimensions (H×W) of the feature map
- Uses `transformer_depth=1`, i.e., one Transformer layer per position
- When xformers is available, uses the `efficient_forward` path for efficient attention

**TemporalTransformer** (temporal attention):
- Location: `src/unifolm_wma/modules/attention.py:561-680`
- Runs self-attention over the temporal dimension (T=16 frames)
- Configured with `temporal_selfatt_only=True` (temporal self-attention only, no cross-attention)
- Relative position encoding: `use_relative_position=False` (not enabled in practice)

**CrossAttention** (core attention layer):
- Location: `src/unifolm_wma/modules/attention.py:48-398`
- Supports multiple cross-attention sources: image, text, state, action
- Automatically uses `xformers.ops.memory_efficient_attention` when xformers is available

**Attention head configuration**:
- `num_head_channels=64`: 64 channels per head
- At 320 channels: 320/64 = 5 attention heads
- At 640 channels: 640/64 = 10 attention heads
- At 1280 channels: 1280/64 = 20 attention heads
### A.4 VAE Actual Configuration

**Config file**: `configs/inference/world_model_interaction.yaml:159-180`

```yaml
AutoencoderKL:
  embed_dim: 4          # latent dimension
  z_channels: 4         # latent channels
  resolution: 256       # base resolution
  in_channels: 3        # RGB input
  out_ch: 3             # RGB output
  ch: 128               # base channel count
  ch_mult: [1, 2, 4, 4] # channel multipliers
  num_res_blocks: 2     # 2 ResBlocks per level
  attn_resolutions: []  # no attention in the VAE
  dropout: 0.0
```

**Encoder architecture**:
```
Input: [B, 3, 320, 512]
  ↓ Conv 3→128
  ↓ ResBlock×2  [128, 320, 512]
  ↓ Downsample  [128, 160, 256]
  ↓ ResBlock×2  [256, 160, 256]
  ↓ Downsample  [256, 80, 128]
  ↓ ResBlock×2  [512, 80, 128]
  ↓ Downsample  [512, 40, 64]
  ↓ ResBlock×2  [512, 40, 64]
  ↓ ResBlock + Conv
Output: [B, 4, 40, 64] (8×8 downsampling)
```

**Decoder architecture** (mirror of the encoder):
```
Input: [B, 4, 40, 64]
  ↓ Conv + ResBlock
  ↓ ResBlock×2  [512, 40, 64]
  ↓ Upsample    [512, 80, 128]
  ↓ ResBlock×2  [512, 80, 128]
  ↓ Upsample    [256, 160, 256]
  ↓ ResBlock×2  [256, 160, 256]
  ↓ Upsample    [128, 320, 512]
  ↓ ResBlock×2  [128, 320, 512]
  ↓ Conv 128→3
Output: [B, 3, 320, 512]
```
### A.5 Condition Encoders Actual Configuration

#### CLIP Image Encoder

**Config**: `configs/inference/world_model_interaction.yaml:188-191`
```yaml
FrozenOpenCLIPImageEmbedderV2:
  freeze: true
  # uses the OpenCLIP ViT-H/14 model
  # output dimension: 1280
```

**Image projector (Resampler)**:
```yaml
Resampler:
  dim: 1024           # output dimension
  depth: 4            # Transformer depth
  dim_head: 64        # per-head dimension
  heads: 12           # 12 attention heads
  num_queries: 16     # 16 query tokens
  embedding_dim: 1280 # CLIP output dimension
  output_dim: 1024    # final output dimension
  video_length: 16    # video length
```

**Data flow**:
```
Image [B, 3, H, W]
  ↓ CLIP Encoder
  ↓ [B, 1280]
  ↓ Resampler (Perceiver-style)
  ↓ [B, 16, 1024] (16 tokens, 1024-dim each)
```

#### Text Encoder

**Config**: `configs/inference/world_model_interaction.yaml:182-186`
```yaml
FrozenOpenCLIPEmbedder:
  freeze: True
  layer: "penultimate" # use the second-to-last layer
  # output dimension: 1024
```

**Data flow**:
```
Text instruction "pick up the box"
  ↓ OpenCLIP Text Encoder
  ↓ [B, seq_len, 1024]
```

#### Action/State Projectors

**Code location**: [src/unifolm_wma/models/ddpms.py:2014-2026](src/unifolm_wma/models/ddpms.py)

**MLPProjector implementation** (src/unifolm_wma/utils/projector.py:14-37, abridged excerpt as in Section 3.5.3):
```python
class MLPProjector(nn.Module):
    def __init__(self, input_dim: int, output_dim: int, mlp_type: str = "gelu-mlp"):
        super().__init__()
        if mlp_type == "gelu-mlp":
            self.projector = nn.Sequential(
                nn.Linear(input_dim, output_dim, bias=True),
                nn.GELU(approximate='tanh'),
                nn.Linear(output_dim, output_dim, bias=True),
            )

    def forward(self, x):
        return self.projector(x)
```

**Initialization code** (ddpms.py:2014-2026):
```python
# State/action projectors
self.state_projector = MLPProjector(agent_state_dim, 1024)   # 16 → 1024
self.action_projector = MLPProjector(agent_action_dim, 1024) # 16 → 1024

# Positional embeddings
self.agent_action_pos_emb = nn.Parameter(torch.randn(1, 16, 1024))
self.agent_state_pos_emb = nn.Parameter(torch.randn(1, n_obs_steps, 1024))
```

**Data flow**:
```
State [B, T_obs, 16]
  ↓ MLPProjector (Linear + GELU + Linear)
  ↓ [B, T_obs, 1024]
  ↓ + agent_state_pos_emb
  ↓ [B, T_obs, 1024]

Action [B, T_action, 16]
  ↓ MLPProjector (Linear + GELU + Linear)
  ↓ [B, T_action, 1024]
  ↓ + agent_action_pos_emb
  ↓ [B, T_action, 1024]
```
### A.6 Timestep Embedding Actual Implementation

**Location**: `src/unifolm_wma/utils/diffusion.py:timestep_embedding`

**Actual code**:
```python
def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half) / half
    )
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    return embedding
```

**Usage in WMAModel**:
```python
# Timestep t ∈ [0, 999]
t_emb = timestep_embedding(t, model_channels)  # [B, 320]
t_emb = self.time_embed(t_emb)  # Linear(320→1280) + SiLU + Linear(1280→1280)
# Output: [B, 1280]
```

**Time embedding network**:
```
t ∈ [0, 999]
  ↓ timestep_embedding (sinusoidal encoding)
  ↓ [B, 320]
  ↓ Linear(320 → 1280)
  ↓ SiLU
  ↓ Linear(1280 → 1280)
  ↓ [B, 1280]
```
### A.7 Action Head (ConditionalUnet1D) Actual Configuration

**Config**: `configs/inference/world_model_interaction.yaml:106-127`

```yaml
ConditionalUnet1D:
  input_dim: 16                     # action dimension
  n_obs_steps: 2                    # observation steps
  diffusion_step_embed_dim: 128     # diffusion step embedding dimension
  down_dims: [256, 512, 1024, 2048] # downsampling channels
  kernel_size: 5                    # convolution kernel size
  n_groups: 8                       # GroupNorm group count
  horizon: 16                       # prediction horizon
  use_linear_attn: true             # use linear attention
  imagen_cond_gradient: true        # use the image-condition gradient
```

**Architecture**:
```
Input: noisy actions [B, 16, 16] (16-dim actions × 16 steps)
Conditions:
  - image features [B, C, H, W] from WMAModel intermediate layers
  - observation encoding [B, n_obs, obs_dim]

  ↓ Conv1D + ResBlock
  ↓ [B, 256, 16]
  ↓ Downsample + ResBlock
  ↓ [B, 512, 8]
  ↓ Downsample + ResBlock
  ↓ [B, 1024, 4]
  ↓ Downsample + ResBlock
  ↓ [B, 2048, 2]
  ↓ Middle Block (with attention)
  ↓ Upsample + ResBlock
  ↓ [B, 1024, 4]
  ↓ Upsample + ResBlock
  ↓ [B, 512, 8]
  ↓ Upsample + ResBlock
  ↓ [B, 256, 16]
  ↓ Conv1D
Output: predicted noise [B, 16, 16]
```
### A.8 Complete Forward-Pass Flow

Based on the actual code, the complete forward pass proceeds as follows (pseudocode):

```python
# 1. Condition encoding (done once; cacheable)
cond_img_emb = clip_encoder(img) → resampler → [B, 16, 1024]
cond_text_emb = text_encoder(text) → [B, seq_len, 1024]
cond_state_emb = state_projector(state) + pos_emb → [B, T_obs, 1024]
cond_action_emb = action_projector(action) + pos_emb → [B, T_action, 1024]
cond_latent = vae.encode(img) → [B, 4, T, 40, 64]

# 2. Assemble conditions
c_concat = [cond_latent]  # channel concatenation
c_crossattn = [cond_text_emb, cond_img_emb, cond_state_emb, cond_action_emb]
c_crossattn = torch.cat(c_crossattn, dim=1)  # [B, total_tokens, 1024]

# 3. DDIM sampling loop (ddim_steps defaults to 50; controlled by --ddim_steps)
x = torch.randn([B, 4, 16, 40, 64])  # initial noise
for step in range(ddim_steps):
    # 3.1 Timestep embedding
    t_emb = timestep_embedding(t, 320) → Linear → [B, 1280]

    # 3.2 Concatenate inputs
    x_in = torch.cat([x, cond_latent], dim=1)  # [B, 8, 16, 40, 64]

    # 3.3 UNet forward pass (the core bottleneck)
    noise_pred = wma_model(x_in, t_emb, c_crossattn)
    # Contains: 4 downsampling stages + middle block + 3 upsampling stages
    # Each stage: 2 ResBlocks + SpatialTransformer + TemporalTransformer

    # 3.4 DDIM update
    x = ddim_update(x, noise_pred, t, t_prev)

# 4. VAE decoding
video = vae.decode(x) → [B, 3, 16, 320, 512]
```
### A.9 Updated Optimization Recommendations Based on the Actual Architecture

**Takeaway**:
This profiling run shows that sampling dominates outright: one 50-step sampling run averages 35.58 s, and each iteration contains one sampling run each for action_generation and world_model_interaction. In other words, any small per-step improvement is amplified by 50 steps and 2 stages, so the most effective optimizations either reduce the step count or substantially speed up the UNet forward pass. CUDA time is concentrated in Linear/GEMM (29.8%) and Convolution (13.9%) while Attention is only about 3.0%, so at the operator level the matmul/convolution paths offer the most reliable gains. CPU-side `aten::copy_/to/_to_copy` time is also prominent, which means data movement inside the loop still has cost to recover.

#### Optimization 1: Sampling Steps and Sampler (highest priority)

**Evidence**:
- 50-step sampling averages 35.58 s (0.712 s/step); reducing the step count yields near-linear gains (see the estimate below)
- One iteration takes about 76.07 s, of which sampling accounts for roughly 93%

**Suggestions**:
- Lower `--ddim_steps` from 50 to 20-30, quality permitting
- Evaluate faster samplers (e.g., DPM-Solver++/UniPC) to cut the step count further
- If using CFG, note that `unconditional_guidance_scale > 1.0` doubles the forward passes per step
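The near-linear claim follows from the measured per-step cost, with two sampling stages per iteration:

$$
T_{\text{sampling per iteration}}(N) \approx 0.712\,\mathrm{s} \times N \times 2
\quad\Rightarrow\quad
T(20) \approx 28.5\,\mathrm{s}
\quad\text{vs.}\quad
T(50) \approx 71.2\,\mathrm{s}
$$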
#### Optimization 2: Speed Up the GEMM/Conv-Dominated Path (high priority)

**Evidence**:
- The bulk of CUDA time comes from Linear/GEMM and Convolution

**Suggestions** (a combined setup sketch follows below):
- Wrap only the UNet trunk in `torch.compile()` to capture fusion gains
- Enable mixed precision (`autocast`) plus TF32 (`torch.backends.cuda.matmul.allow_tf32 = True`)
- With fixed input shapes, enable `torch.backends.cudnn.benchmark = True`
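A sketch combining the switches above with the compile and autocast calls already shown in Section 5.4; placement in the evaluation script is assumed, and the `...` arguments are elided as before:

```python
import torch

# One-time global switches, set before any inference
torch.backends.cuda.matmul.allow_tf32 = True  # TF32 for GEMMs (Ampere and newer)
torch.backends.cudnn.allow_tf32 = True        # TF32 for cuDNN convolutions
torch.backends.cudnn.benchmark = True         # autotune conv algorithms for fixed shapes

# Compile only the UNet trunk, as recommended above
model.model.diffusion_model = torch.compile(model.model.diffusion_model)

# Mixed precision around the sampling call
with torch.cuda.amp.autocast():
    pred_videos, pred_actions, pred_states = image_guided_synthesis_sim_mode(...)
```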
#### Optimization 3: ResBlock Fusion (medium-high priority)

**Actual load**:
- Each DDIM step calls the UNet once
- The UNet contains 4 downsampling stages + 1 middle block + 3 upsampling stages = 8 stages
- Each stage has 2 ResBlocks
- Total: 16 ResBlocks × 50 steps × 2 stages (Stage 1 + Stage 2) = **1600 ResBlock invocations**

**Fusion opportunity** (the fused functions below are illustrative names, not existing kernels):
```python
# Current: 7 kernel launches
h = group_norm(x)   # kernel 1
h = silu(h)         # kernel 2
h = conv2d(h)       # kernel 3
h = group_norm(h)   # kernel 4
h = silu(h)         # kernel 5
h = conv2d(h)       # kernel 6
out = x + h         # kernel 7

# After fusion: 2-3 kernel launches
h = fused_norm_silu_conv(x)      # kernel 1 (fused)
h = fused_norm_silu_conv(h)      # kernel 2 (fused)
out = fused_residual_add(x, h)   # kernel 3 (fused)
```

**Expected gain**: 50-60% of kernel launch overhead saved per ResBlock
#### Optimization 4: Attention (medium priority)

**Actual setup**:
- SpatialTransformer: after each ResBlock in every stage
- TemporalTransformer: after each ResBlock in every stage
- Total: 16 Spatial + 16 Temporal = **32 Transformers × 50 steps × 2 stages = 3200 attention invocations**

**Interpretation**:
Attention accounts for only about 3% of operator time, so it is not the current bottleneck, but enabling an efficient implementation still yields a reliable gain if one is not already active.

**Optimizations**:
- Confirm xformers is enabled (`XFORMERS_IS_AVAILBLE` is True)
- Without xformers, switch to PyTorch SDPA:

```python
from torch.nn.functional import scaled_dot_product_attention
out = scaled_dot_product_attention(Q, K, V, is_causal=False)
```
#### Optimization 5: Data Movement and CPU Overhead (medium priority)

**Evidence**:
- `aten::copy_` and `aten::to/_to_copy` stand out in CPU-side time

**Suggestions** (a hoisting sketch follows below):
- Avoid repeated `.to(device)` / `.float()` / `.half()` inside the DDIM loop
- Move constant tensors (e.g., timesteps, sigmas) to the GPU ahead of time
- Minimize temporary tensor creation and `clone()`, especially at per-step granularity
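A minimal before/after sketch of hoisting the per-step timestep transfer out of the loop; the schedule and batch size here are illustrative:

```python
import torch

device = torch.device("cuda")
b = 1
timesteps = list(range(0, 1000, 20))  # illustrative 50-step DDIM schedule

# Before (per-step host→device copy, visible as aten::to/_to_copy):
#   for step in timesteps:
#       ts = torch.full((b,), step, dtype=torch.long).to(device)

# After: build all step tensors on the GPU once, then index into them
ts_all = torch.tensor(timesteps, device=device, dtype=torch.long)
ts_all = ts_all.unsqueeze(1).expand(-1, b)  # [n_steps, b]

for i in range(ts_all.shape[0]):
    ts = ts_all[i]  # indexing a GPU-resident tensor: no copies inside the loop
    # ... UNet forward with ts ...
```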
#### Optimization 6: VAE Decoding (low priority)

**Evidence**:
- VAE decoding costs 0.57 s/call, secondary to the sampling bottleneck

**Suggestions**:
- Decode under `autocast` consistently
- If a slight quality drop is acceptable, decode less frequently or at a lower resolution