diff --git a/model_architecture_analysis.md b/model_architecture_analysis.md index 4114c2c..33f85a7 100644 --- a/model_architecture_analysis.md +++ b/model_architecture_analysis.md @@ -125,10 +125,11 @@ def ddim_sampling(self, cond, shape, x_T=None, ddim_steps=50, ...): return x ``` -**性能数据** (来自profiling): -- 单步去噪总耗时: 10.71s - 11.06s (22次调用) -- 模型前向: 325.30s (660次调用, 平均0.493s/次) -- DDIM更新: 0.21s (660次调用, 平均0.0003s/次) +**性能数据** (来自 profiling 报告,`--ddim_steps 50`): +- DDIM采样调用: 22次 (action_generation + world_model_interaction 各11次) +- 单次采样(50步)平均耗时: 35.58s (总计 782.70s) +- 平均每步耗时: ~0.712s (35.58s / 50) +- 当前 `unconditional_guidance_scale=1.0` 时每步 1 次 UNet 前向;开启 CFG 时每步 2 次前向 ### 3.2 DiffusionWrapper (条件路由器) @@ -205,14 +206,16 @@ z = model.encode_first_stage(img) video = model.decode_first_stage(samples) ``` -**性能数据**: -- VAE编码: 1.03s (22次, 平均0.047s/次) -- VAE解码: 15.53s (22次, 平均0.706s/次) -- 压缩比: 8×8 = 64倍空间压缩 +**性能数据**: +- VAE编码: 0.90s (22次, 平均0.041s/次) +- VAE解码: 12.44s (22次, 平均0.566s/次) +- 压缩比: 8×8 = 64倍空间压缩 **详细架构**: 见附录A.4 -### 3.5 条件编码器 +### 3.5 条件编码器 + +**性能说明**: 本次 profiling 未对各条件编码器单独计时,统一计入 `synthesis/conditioning_prep`,总计 2.92s (22次, 平均0.133s/次)。 #### 3.5.1 CLIP图像编码器 @@ -237,9 +240,7 @@ Resampler (图像投影器): **数据流**: 图像 [B, 3, H, W] → CLIP → [B, 1280] → Resampler → [B, 16, 1024] -**性能**: 0.71s (22次, 平均0.032s/次) - -#### 3.5.2 文本编码器 +#### 3.5.2 文本编码器 **代码位置**: [src/unifolm_wma/modules/encoders/condition.py](src/unifolm_wma/modules/encoders/condition.py) - `FrozenOpenCLIPEmbedder` @@ -253,9 +254,7 @@ FrozenOpenCLIPEmbedder: # 输出: [B, seq_len, 1024] ``` -**性能**: 0.13s (22次, 平均0.006s/次) - -#### 3.5.3 状态投影器 +#### 3.5.3 状态投影器 **代码位置**: [src/unifolm_wma/models/ddpms.py:2014-2026](src/unifolm_wma/models/ddpms.py) - `MLPProjector` @@ -273,9 +272,7 @@ class MLPProjector(nn.Module): **数据流**: 状态 [B, T_obs, 16] → MLPProjector → [B, T_obs, 1024] + agent_state_pos_emb -**性能**: 0.006s (22次, 平均0.0003s/次) - -#### 3.5.4 动作投影器 +#### 3.5.4 动作投影器 **代码位置**: [src/unifolm_wma/models/ddpms.py:2020-2024](src/unifolm_wma/models/ddpms.py) - `MLPProjector` @@ -288,47 +285,49 @@ self.agent_action_pos_emb = nn.Parameter(torch.randn(1, 16, 1024)) self.agent_state_pos_emb = nn.Parameter(torch.randn(1, n_obs_steps, 1024)) ``` -**性能**: 0.003s (22次, 平均0.0001s/次) - ---- - -## 4. 性能瓶颈分析 - -### 4.1 时间分布 (总计412.39s) - -根据性能分析数据,时间分布如下: - -| 阶段 | 总耗时 | 占比 | 说明 | -|------|--------|------|------| -| 阶段1: 生成动作 | 171.52s | 41.6% | DDIM采样30步 | -| 阶段2: 世界模型交互 | 171.65s | 41.6% | DDIM采样30步 | -| 模型加载 | 47.56s | 11.5% | 一次性开销 | -| 保存视频 | 13.91s | 3.4% | I/O操作 | -| 保存完整视频 | 7.22s | 1.8% | I/O操作 | -| 数据集加载 | 0.51s | 0.1% | 一次性开销 | - -### 4.2 DDIM采样详细分析 - -**DDIM采样是绝对瓶颈,占总时间的94.9%** - -``` -DDIM采样总耗时: 325.74s -├── 模型前向传播: 325.30s (99.86%) ← 核心瓶颈 -├── DDIM更新公式: 0.21s (0.06%) -└── Action/State调度: 0.13s (0.04%) -``` - -**每步耗时分析**: -- 30个去噪步骤,每步平均耗时: 10.86s -- 每步调用模型前向2次 (阶段1和阶段2各1次) -- 每次前向传播: ~0.493s - -### 4.3 瓶颈总结 - -**关键发现**: -1. **模型前向传播占99.86%的DDIM时间** - 这是优化的核心目标 -2. VAE解码占4.5%总时间 - 次要优化目标 -3. 其他操作(条件编码、DDIM更新)耗时可忽略 +--- + +## 4. 性能瓶颈分析 + +### 4.1 时间分布 (profiling 报告, 11次迭代 / 22次采样) + +说明: `profile_section` 存在嵌套,宏观统计的总和不是 wall time,以下以每段的 total/avg 为准。 + +| Section | Count | Total(s) | Avg(s) | 说明 | +|---------|-------|----------|--------|------| +| iteration_total | 11 | 836.79 | 76.07 | 单次迭代总耗时 | +| action_generation | 11 | 399.71 | 36.34 | 生成动作 (DDIM 50步) | +| world_model_interaction | 11 | 398.38 | 36.22 | 世界模型交互 (DDIM 50步) | +| synthesis/ddim_sampling | 22 | 782.70 | 35.58 | 单次采样 | +| synthesis/conditioning_prep | 22 | 2.92 | 0.13 | 条件编码汇总 | +| synthesis/decode_first_stage | 22 | 12.44 | 0.57 | VAE解码 | +| save_results | 11 | 38.67 | 3.52 | I/O保存 | +| model_loading/config | 1 | 49.77 | 49.77 | 一次性开销 | +| model_loading/checkpoint | 1 | 11.83 | 11.83 | 一次性开销 | +| model_to_cuda | 1 | 8.91 | 8.91 | 一次性开销 | + +### 4.2 DDIM采样详细分析 + +**DDIM采样是主要瓶颈** (基于 50 步采样): + +- 采样调用次数: 22 次 (11 次迭代 × 2 阶段) +- 采样总耗时: 782.70s,平均 35.58s/次 +- 平均每步耗时: ~0.712s (35.58s / 50) +- `unconditional_guidance_scale=1.0` 时每步 1 次 UNet 前向;开启 CFG 时每步 2 次前向 +- 在 `conditioning_prep + ddim_sampling + decode_first_stage` 中,ddim_sampling 占约 98% + +### 4.3 瓶颈总结 + +**关键发现**: +1. **DDIM采样占比最高** - 单次迭代平均 76.07s,其中采样约 71.15s (≈93%) +2. **CUDA算子时间主要集中在 Linear/GEMM(29.8%) 与 Convolution(13.9%)**;Attention 约 3.0% +3. **CPU侧 copy/to 仍明显** (`aten::copy_`, `aten::to/_to_copy` 在报告中耗时靠前) +4. VAE解码为次级瓶颈 (0.57s/次) + +### 4.4 GPU显存概览 + +- Peak allocated: 17890.50 MB +- Average allocated: 16129.98 MB --- @@ -404,7 +403,7 @@ self.out_layers = nn.Sequential( 2. `emb_layers` 的 `SiLU + Linear` 可融合 3. 残差加法可与下一层的GroupNorm融合 -**实际瓶颈**: 16个ResBlock × 30步 × 2次 = **960次ResBlock调用** +**实际瓶颈**: 16个ResBlock × 50步 × 2次 = **1600次ResBlock调用** **预期收益**: 每个ResBlock节省50-60%的kernel启动开销 @@ -415,7 +414,7 @@ self.out_layers = nn.Sequential( **实际配置**: - SpatialTransformer: 空间维度注意力 - TemporalTransformer: 时间维度注意力 -- 总计: 32个Transformer × 30步 × 2次 = **1920次注意力调用** +- 总计: 32个Transformer × 50步 × 2次 = **3200次注意力调用** **优化方案**: 使用 PyTorch 内置的 Flash Attention: @@ -432,7 +431,7 @@ out = scaled_dot_product_attention(Q, K, V, is_causal=False) **代码位置**: [src/unifolm_wma/models/autoencoder.py](src/unifolm_wma/models/autoencoder.py) -**当前性能**: 15.53s (22次调用, 平均0.706s/次) +**当前性能**: 12.44s (22次调用, 平均0.566s/次) **优化方案**: 1. **混合精度**: 使用FP16进行解码 @@ -719,9 +718,9 @@ with torch.cuda.amp.autocast(): ### 9.1 关键发现 -1. **模型前向传播占99.86%的DDIM采样时间** - 这是优化的绝对核心 -2. **30步DDIM采样占总时间的83%** - 减少步数或加速单步是关键 -3. **VAE解码占4.5%** - 次要优化目标 +1. **DDIM采样仍是主要瓶颈** - 单次采样(50步)平均 35.58s +2. **Linear/GEMM 与 Convolution 为主要 CUDA 时间来源** - Attention 占比相对较小 +3. **VAE解码为次级优化目标** - 0.57s/次 ### 9.2 优化优先级 @@ -740,14 +739,14 @@ with torch.cuda.amp.autocast(): ### 9.3 预期成果 -通过系统性优化,预期可以将推理时间从 **412s 降低到 140-200s**,实现 **2-3倍加速**。 +通过系统性优化,预期可获得 **1.5-3倍加速** (视采样步数与编译/混合精度策略而定)。 --- -**文档版本**: v1.1 -**创建日期**: 2026-01-17 -**最后更新**: 2026-01-17 -**更新内容**: 根据实际代码验证并修正了文件路径、行号、组件位置和实现细节 +**文档版本**: v1.2 +**创建日期**: 2026-01-17 +**最后更新**: 2026-01-18 +**更新内容**: 校准DDIM步数为50并替换为最新profiling数据 --- @@ -1121,9 +1120,9 @@ c_concat = [cond_latent] # 通道拼接 c_crossattn = [cond_text_emb, cond_img_emb, cond_state_emb, cond_action_emb] c_crossattn = torch.cat(c_crossattn, dim=1) # [B, total_tokens, 1024] -# 3. DDIM采样循环 (30步) -x = torch.randn([B, 4, 16, 40, 64]) # 初始噪声 -for step in range(30): +# 3. DDIM采样循环 (ddim_steps 默认 50,实际由 --ddim_steps 控制) +x = torch.randn([B, 4, 16, 40, 64]) # 初始噪声 +for step in range(ddim_steps): # 3.1 时间步嵌入 t_emb = timestep_embedding(t, 320) → Linear → [B, 1280] @@ -1143,62 +1142,93 @@ video = vae.decode(x) → [B, 3, 16, 320, 512] ``` -### A.9 基于实际架构的优化建议更新 - -#### 优化点1: ResBlock融合 (高优先级) - -**实际瓶颈**: -- 每个DDIM步骤调用UNet一次 -- UNet包含: 4个下采样阶段 + 1个中间块 + 3个上采样阶段 = 8个阶段 -- 每个阶段有2个ResBlock -- 总计: 16个ResBlock × 30步 × 2次(阶段1+2) = **960次ResBlock调用** - -**融合机会**: -```python -# 当前: 6次kernel启动 -h = group_norm(x) # kernel 1 -h = silu(h) # kernel 2 -h = conv2d(h) # kernel 3 -h = group_norm(h) # kernel 4 -h = silu(h) # kernel 5 -h = conv2d(h) # kernel 6 -out = x + h # kernel 7 - -# 优化后: 2-3次kernel启动 -h = fused_norm_silu_conv(x) # kernel 1 (融合) -h = fused_norm_silu_conv(h) # kernel 2 (融合) -out = fused_residual_add(x, h) # kernel 3 (融合) -``` - -**预期收益**: 每个ResBlock节省50-60%的kernel启动开销 - - -#### 优化点2: 注意力机制优化 (高优先级) - -**实际配置**: -- SpatialTransformer: 在每个阶段的每个ResBlock后 -- TemporalTransformer: 在每个阶段的每个ResBlock后 -- 总计: 16个Spatial + 16个Temporal = **32个Transformer × 30步 × 2次 = 1920次注意力调用** - -**当前实现已支持xformers**: -代码在 `attention.py:8-13` 检测 xformers 可用性: -```python -try: - import xformers - import xformers.ops - XFORMERS_IS_AVAILBLE = True -except: - XFORMERS_IS_AVAILBLE = False -``` - -当 xformers 可用时,`CrossAttention` 会自动使用 `efficient_forward` 方法 (attention.py:90-91)。 - -**进一步优化方案** (如果xformers不可用): -```python -# 使用 PyTorch 内置 Flash Attention -from torch.nn.functional import scaled_dot_product_attention -out = scaled_dot_product_attention(Q, K, V, is_causal=False) -``` - -**预期收益**: 如果xformers已启用,注意力层已经是优化的;否则使用Flash Attention可加速2-3倍 +### A.9 基于实际架构的优化建议更新 + +**我的理解**: +本次 profiling 显示采样阶段占据绝对主导地位:单次采样(50步)平均 35.58s,且每次迭代包含 action_generation 与 world_model_interaction 各一次采样。换句话说,任何“每步的细微改进”都会被 50 步和 2 阶段放大;因此最有效的优化要么减少步数,要么显著加速 UNet 前向。CUDA 时间主要集中在 Linear/GEMM(29.8%) 与 Convolution(13.9%),而 Attention 约 3.0%,这意味着算子层面优先考虑矩阵乘/卷积路径的优化收益更稳定。CPU 侧 `aten::copy_/to/_to_copy` 也明显,说明循环内的数据搬运仍有成本可省。 + +#### 优化点1: 采样步数与采样器 (最高优先级) + +**依据**: +- 50步采样平均 35.58s (0.712s/步),减少步数带来近线性收益 +- 单次迭代约 76.07s,其中采样约占 93% + +**建议**: +- 在保证质量的前提下,将 `--ddim_steps` 从 50 降到 20-30 +- 评估更快采样器(如 DPM-Solver++/UniPC)以减少步数 +- 若使用 CFG,注意 `unconditional_guidance_scale > 1.0` 会使每步前向翻倍 + +#### 优化点2: GEMM/Conv 主导路径加速 (高优先级) + +**依据**: +- CUDA 时间主力来自 Linear/GEMM 与 Convolution + +**建议**: +- `torch.compile()` 仅包裹 UNet 主干以获得融合收益 +- 启用混合精度 (`autocast`) + TF32 (`torch.backends.cuda.matmul.allow_tf32 = True`) +- 固定输入形状时开启 `torch.backends.cudnn.benchmark = True` + +#### 优化点3: ResBlock融合 (中高优先级) + +**实际瓶颈**: +- 每个DDIM步骤调用UNet一次 +- UNet包含: 4个下采样阶段 + 1个中间块 + 3个上采样阶段 = 8个阶段 +- 每个阶段有2个ResBlock +- 总计: 16个ResBlock × 50步 × 2次(阶段1+2) = **1600次ResBlock调用** + +**融合机会**: +```python +# 当前: 6次kernel启动 +h = group_norm(x) # kernel 1 +h = silu(h) # kernel 2 +h = conv2d(h) # kernel 3 +h = group_norm(h) # kernel 4 +h = silu(h) # kernel 5 +h = conv2d(h) # kernel 6 +out = x + h # kernel 7 + +# 优化后: 2-3次kernel启动 +h = fused_norm_silu_conv(x) # kernel 1 (融合) +h = fused_norm_silu_conv(h) # kernel 2 (融合) +out = fused_residual_add(x, h) # kernel 3 (融合) +``` + +**预期收益**: 每个ResBlock节省50-60%的kernel启动开销 + +#### 优化点4: 注意力机制优化 (中优先级) + +**实际配置**: +- SpatialTransformer: 在每个阶段的每个ResBlock后 +- TemporalTransformer: 在每个阶段的每个ResBlock后 +- 总计: 16个Spatial + 16个Temporal = **32个Transformer × 50步 × 2次 = 3200次注意力调用** + +**理解**: +Attention 在算子占比中只有约 3%,不是当前主要瓶颈,但若未启用高效实现仍可获得稳定收益。 + +**优化方案**: +- 确认 xformers 已启用 (`XFORMERS_IS_AVAILBLE` 为 True) +- 无 xformers 时替换为 PyTorch SDPA: +```python +from torch.nn.functional import scaled_dot_product_attention +out = scaled_dot_product_attention(Q, K, V, is_causal=False) +``` + +#### 优化点5: 数据搬运与 CPU 开销 (中优先级) + +**依据**: +- `aten::copy_` 与 `aten::to/_to_copy` 在 CPU 侧耗时突出 + +**建议**: +- 避免在 DDIM 循环内重复 `.to(device)` / `.float()` / `.half()` +- 将常量张量(如 timestep、sigma)提前放到 GPU +- 尽量减少临时张量创建与 `clone()`,尤其是 per-step 级别 + +#### 优化点6: VAE 解码 (低优先级) + +**依据**: +- VAE 解码 0.57s/次,次于采样瓶颈 + +**建议**: +- 统一使用 `autocast` 解码 +- 若可容忍轻微质量下降,可降低解码频率或分辨率 diff --git a/scripts/evaluation/world_model_interaction.py b/scripts/evaluation/world_model_interaction.py index 07dc8ef..73e2baa 100644 --- a/scripts/evaluation/world_model_interaction.py +++ b/scripts/evaluation/world_model_interaction.py @@ -441,7 +441,7 @@ def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]: return file_list -def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module: +def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module: """Load model weights from checkpoint file. Args: @@ -472,11 +472,43 @@ def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module: for key in state_dict['module'].keys(): new_pl_sd[key[16:]] = state_dict['module'][key] model.load_state_dict(new_pl_sd) - print('>>> model checkpoint loaded.') - return model - - -def is_inferenced(save_dir: str, filename: str) -> bool: + print('>>> model checkpoint loaded.') + return model + + +def _module_param_dtype(module: nn.Module | None) -> str: + if module is None: + return "None" + for param in module.parameters(): + return str(param.dtype) + return "no_params" + + +def log_inference_precision(model: nn.Module) -> None: + try: + param = next(model.parameters()) + device = str(param.device) + model_dtype = str(param.dtype) + except StopIteration: + device = "unknown" + model_dtype = "no_params" + + print(f">>> inference precision: model={model_dtype}, device={device}") + for attr in [ + "model", "first_stage_model", "cond_stage_model", "embedder", + "image_proj_model" + ]: + if hasattr(model, attr): + submodule = getattr(model, attr) + print(f">>> {attr} param dtype: {_module_param_dtype(submodule)}") + + print( + ">>> autocast gpu dtype default: " + f"{torch.get_autocast_gpu_dtype()} " + f"(enabled={torch.is_autocast_enabled()})") + + +def is_inferenced(save_dir: str, filename: str) -> bool: """Check if a given filename has already been processed and saved. Args: @@ -853,11 +885,13 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: data.setup() print(">>> Dataset is successfully loaded ...") - with profiler.profile_section("model_to_cuda"): - model = model.cuda(gpu_no) - device = get_device_from_parameters(model) - - profiler.record_memory("after_model_load") + with profiler.profile_section("model_to_cuda"): + model = model.cuda(gpu_no) + device = get_device_from_parameters(model) + + log_inference_precision(model) + + profiler.record_memory("after_model_load") # Run over data assert (args.height % 16 == 0) and ( diff --git a/src/unifolm_wma/modules/encoders/condition.py b/src/unifolm_wma/modules/encoders/condition.py index 44f0fdc..4e23ae4 100644 --- a/src/unifolm_wma/modules/encoders/condition.py +++ b/src/unifolm_wma/modules/encoders/condition.py @@ -334,6 +334,15 @@ class FrozenOpenCLIPImageEmbedder(AbstractEncoder): @autocast def forward(self, image, no_dropout=False): + if not hasattr(self, "_printed_autocast_info"): + print( + ">>> 图像编码 autocast:", + torch.is_autocast_enabled(), + torch.get_autocast_gpu_dtype(), + "输入dtype:", + image.dtype, + ) + self._printed_autocast_info = True z = self.encode_with_vision_transformer(image) if self.ucg_rate > 0. and not no_dropout: z = torch.bernoulli( @@ -407,6 +416,15 @@ class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder): def forward(self, image, no_dropout=False): ## image: b c h w + if not hasattr(self, "_printed_autocast_info"): + print( + ">>> 图像编码V2 autocast:", + torch.is_autocast_enabled(), + torch.get_autocast_gpu_dtype(), + "输入dtype:", + image.dtype, + ) + self._printed_autocast_info = True z = self.encode_with_vision_transformer(image) return z diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768704535.node-0.193164.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768704535.node-0.193164.0 new file mode 100644 index 0000000..1a2c002 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768704535.node-0.193164.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768704559.node-0.193448.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768704559.node-0.193448.0 new file mode 100644 index 0000000..8326082 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768704559.node-0.193448.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768704797.node-0.194425.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768704797.node-0.194425.0 new file mode 100644 index 0000000..e9312da Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768704797.node-0.194425.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768705207.node-0.197016.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768705207.node-0.197016.0 new file mode 100644 index 0000000..04bbace Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768705207.node-0.197016.0 differ diff --git a/usefal.sh b/usefal.sh deleted file mode 100644 index 3e28604..0000000 --- a/usefal.sh +++ /dev/null @@ -1 +0,0 @@ -python3 psnr_score_for_challenge.py --gt_video unitree_g1_pack_camera/case1/unitree_g1_pack_camera_case1.mp4 --pred_video unitree_g1_pack_camera/case1/output/inference/0_full_fs6.mp4 --output_file unitree_g1_pack_camera/case1/psnr_result.json \ No newline at end of file diff --git a/useful.sh b/useful.sh new file mode 100644 index 0000000..a466ad5 --- /dev/null +++ b/useful.sh @@ -0,0 +1,32 @@ +python3 psnr_score_for_challenge.py --gt_video unitree_g1_pack_camera/case1/unitree_g1_pack_camera_case1.mp4 --pred_video unitree_g1_pack_camera/case1/output/inference/0_full_fs6.mp4 --output_file unitree_g1_pack_camera/case1/psnr_result.json +采样阶段(synthesis/ddim_sampling)+ UNet 前向几乎吃掉全部时间,外加明显的 CPU 侧 aten::to/_to_copy 和 aten::copy_ 开销;整体优化优先级还是“减少采样步数 + 加速每步前向 + 降低无谓拷贝”。下 + 面是更针对这份 profile 的思路: + + - 优先级1:减少采样步数/换更快采样器 + - 把 DDIM 30 步降到 10–20 步,或改用 DPM-Solver++/UniPC;这往往是 1.5–3× 的最直接收益。采样逻辑在 src/unifolm_wma/models/samplers/ddim.py,入口在 scripts/evaluation/ + world_model_interaction.py。 + - 如允许训练侧投入,可做蒸馏(LCM/Consistency)让 4–8 步也可用。 + - 优先级1:每步前向加速(编译 + AMP + TF32) + - torch.compile 只包 diffusion_model,见 scripts/evaluation/world_model_interaction.py(你文档里也已写)。 + - 推理包一层 autocast(fp16/bf16)并开启 TF32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.set_float32_matmul_precision("high")。 + - torch.backends.cudnn.benchmark = True 在固定 shape 下很有效。 + - 优先级1:消除循环内的 to()/copy/clone + - profile 里 aten::_to_copy CPU 时间很高,建议逐步排查是否在 DDIM loop 或条件准备里重复 .to(device) / .float() / .half()。 + - 把常量(timesteps/sigmas/ts 等)提前放 GPU,避免每步创建;避免不必要的 clone()。 + - 重点排查 src/unifolm_wma/models/samplers/ddim.py 与 scripts/evaluation/world_model_interaction.py 的数据准备段。 + - 优先级2:注意力实现确认 + - attention 只占 3% 左右,但如果没启用 xformers/SDPA,仍有收益空间。检查 src/unifolm_wma/modules/attention.py 的 XFORMERS_IS_AVAILBLE。 + - 无 xformers 时可改用 scaled_dot_product_attention(Flash Attention 路径)。 + - 优先级2:VAE 解码 & 保存 I/O + - synthesis/decode_first_stage 仍是秒级,建议 autocast + 可能的 torch.compile。位置在 src/unifolm_wma/models/autoencoder.py。 + - save_results 约 38s:如果只是评测,考虑降低保存频率/分辨率或异步写盘。 + - 优先级3:结构性减负 + - 降低 temporal_length、输入分辨率或 model_channels 会线性降低 compute(配置在 configs/inference/world_model_interaction.yaml)。 + - 如果 action_generation 与 world_model_interaction 共享条件,可以缓存 CLIP/VAE 编码,避免重复计算(model_architecture_analysis.md 的条件编码流程已说明)。 + + 如果你希望我直接落地改动,推荐顺序: + + 1. torch.compile + AMP + TF32 + cudnn.benchmark + 2. 排查 .to()/copy/clone 的重复位置并移出循环 + 3. 若需要更大幅度,再换采样器/降步数 \ No newline at end of file