我建议优先做这 4 类,都是跨数据集成立的: 1. 压 rollout 内环实现 见 jepa.py:127。现在每步都在做 action_encoder、切片、torch.cat、小规 模 predict 调用,这种碎片化实现对任何任务都亏。 通用改法: - 整条 action_sequence 一次性做 action_encoder - emb_hist / act_emb_hist 改成预分配 buffer - 循环里只做索引覆盖或 copy_ - 去掉循环内 torch.cat 2. 减少热路径里的搬运和同步 profile 里 aten::copy_ 很重,这不是 TwoRoom 特有问题。重点看 jepa.py:67 和 jepa.py:186。 通用目标: - 模型侧张量尽量全程留在 GPU - 避免热路径反复 .to(device) / 隐式 layout 修复 - 到必须和环境交互的边界再一次性转 CPU / numpy - 确保进入 predictor 的张量是 contiguous 的,少触发隐式 copy 3. 把编译成本移出正式计时 现在 torch.compile 默认开在 predictor,见 eval.py:70。102s -> 45s 很 像首轮编译预热。 通用做法: - 在正式 start_time 前做一次 dummy predict 或 dummy rollout - 保留只编译 predictor/predict,不要编译整个 solver 4. 减少临时对象和 shape bookkeeping 这是所有任务都会受益的。 重点看: - jepa.py:100 到 jepa.py:106 - jepa.py:143 到 jepa.py:148 方向是: - 能循环外做的 reshape,不放循环里 - 能原地更新,不新建张量 - 少做 dict 字段增删和中间容器组装 不建议优先做的通用性较差方案: - 调 TwoRoom 专属 cache 规则 - 改数据集采样逻辑 - 按小数据集特点缩短 horizon - 直接改 CEM 超参当“优化” 如果你要我直接开始改,我建议第一批只做两件事: - 重写 jepa.py:127 这段 rollout,去掉循环内 action_encoder + cat - 在 eval.py:306 前加 compile warmup