Files
lewm/sth.md
2026-04-09 11:11:07 +00:00

192 lines
5.9 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
1. 压 rollout 内环
这条最通用,而且基本不改算法语义,只是把实现做对。
在 jepa.py:129 这段里,当前问题是:
- 循环里每步都 action_encoder(next_act),见 jepa.py:159
- history 每步用 torch.cat 重建,见 jepa.py:155 和 jepa.py:162
- 每步都走一次很短的 predict()host 调度比例很高
通用改法:
- 整条 action_sequence 一次性做 action_encoder
- emb_hist / act_emb_hist 改成预分配 buffer
- 用 ring buffer 或 index rotate 更新历史
- 循环里只做 copy_ / 索引覆盖,不做 cat
这个优化对任何数据集都成立因为它优化的是“rolling inference 实现方
式”,不是任务参数。
2. 用 torch.inference_mode()
你现在在 eval.py:242 这里只用了 autocast没有 inference_mode()。
建议推理主路径外层直接包:
with torch.inference_mode():
with inference_ctx:
...
这是纯通用优化,所有数据集都受益。
3. 只编译 predictor / predict不要编译整个 solver
当前热点是大量小 predict() 调用,不是整条 eval graph。
通用建议:
- 先只编译 self.predictor
- 或只编译 JEPA.predict()
- 模式优先试 reduce-overhead
不要先编译整个 WorldModelPolicy 或 CEM solver那通常图不稳定泛化收益
反而差。
4. 减少循环里的张量形状重排和临时对象
这也是实现层通用优化。
可以继续查:
- rearrange 是否能前移到循环外
- 是否有重复的 slice/view 触发隐式拷贝
- pred_proj(rearrange(...)) 这类 reshape 往返是否能合并
这类优化对所有任务都有效,因为是在降 Python 和 tensor bookkeeping 成
本。
5. 再考虑结构级优化,但放后面
比如 predictor 深度、MLP 宽度、heads 数量。这也通用,但已经开始碰模型容
量和精度,不该是第一刀。
不建议优先做的
这些更偏任务/数据集相关,不算你要的“泛用优化”:
- 先调 num_samples/topk/n_steps
- 先缩 horizon
- 先按 tworoom 特性做 shortcut
- 先针对某个 dataset 做 cache 规则
一句话判断
你现在最像是“算法没错,但 rollout 实现过于碎片化”,所以第一优先级应该
是:
一次性 action encode + 预分配历史 buffer + 去掉循环内 torch.cat +
inference_mode + compile predictor
如果你要,我下一步就直接改 jepa.py 做这套通用优化,不碰任何数据集特化逻
辑。
除开 CEM solve 本体,剩下这些杂项可以这样优化。
最高优先级
1. 保证传给环境的是 numpy不要让 Gym 代转
你日志已经说明 env step() 收到了 torch.Tensor。这会带来拷贝、同步、
checker 额外开销。
做法:
- 在 policy 输出动作、准备喂给 env 的那一层,显式转成
action.detach().cpu().numpy()
- 最好一次性转好,别在 env 内部或 wrapper 内隐式转换
收益:
- 去掉 Gym warning
- 减少同步和类型检查开销
- 通常是最直接的非模型提速点
2. 关掉 Gym passive checker
这些 warning 本身就说明 checker 在持续检查类型和空间匹配。
做法:
- 尽量用禁 checker 的构造方式
- 或在你自己的 env wrapper 里保证输入输出符合 Gym 预期,避免它每步检查
收益:
- 每步少一层 Python 校验
- 对长 episode / 多 episode 累积明显
中优先级
3. 把预处理前移,避免每步重复做能缓存的东西
如果 goal、初始条件、某些 dataset 字段在 episode 里不变,就不要每次
都重新组织。
做法:
- goal 相关 embedding 已经有缓存,继续扩展到更多静态字段
- 固定的 callables 参数尽量预解析
- 能在 episode 开头准备好的,不要放在 step 循环里
收益:
- 降低 Python dict 操作和小张量处理开销
4. 避免频繁 CPU/GPU 来回切
如果模型在 GPU但环境在 CPU就要非常小心中间格式。
做法:
- 模型侧尽量连续留在 GPU
- 到真正 env step 前再一次性转 CPU numpy
- 不要中间反复 .cpu() / .to(device) / np.array(...)
收益:
- 减少隐式同步
- 稳定延迟
5. 缩减 Python 层对象操作
dict 组装、字段拷贝、wrapper 嵌套太多时,端到端会慢。
做法:
- 关键热路径里少做深拷贝
- 少重复构造新的 info / obs 容器
- 固定结构优先原地更新
收益:
- 对小步高频调用路径有效
如果你要继续压评测时间
6. 降低日志和 warning 输出
频繁 warning 会拖慢,而且污染 timing。
做法:
- 修掉类型不匹配后 warning 自然消失
- 非必要的 print 尤其是 step 内 print 要去掉
7. 针对环境 step 做批量化检查
如果 num_envs=50尽量确认 env wrapper 没有在内部退化成逐环境 Python
for-loop。
做法:
- 查 world.evaluate_from_dataset() 到 env step() 之间是不是 batch 接口
- 如果 batch env 里还有逐个 env 转换/检查,尽量前移或向量化
收益:
- 这类经常能解释“为什么 solver 时间之外还有很多时间”
8. 把 callables 的执行成本单独看
你这里有:
- _set_state
- _set_goal_state
- 确认它们只在必要时执行
- 能批量设置就别逐条 Python 调
2. 消掉 Gym warning
3. 单独量 env.step 总时间
4. 检查是否有反复 CPU/GPU 转换
5. 再看 wrapper / callable / obs 组装
一句话总结
剩下的杂项优化,核心不是“再多上几张卡”,而是:
- 去掉隐式类型转换
- 去掉多余检查
- 去掉重复数据整理
- 减少 CPU/GPU 往返
- 减少 Python 高频小开销
如果你要,我下一步可以直接帮你定位“动作是在哪一层以 torch.Tensor 传进
env 的”,给你指出具体应该改哪个函数。