From 995cd8cfec4aba4fae519681849c455ba16db14c Mon Sep 17 00:00:00 2001 From: qihuanye Date: Thu, 9 Apr 2026 11:57:09 +0000 Subject: [PATCH] =?UTF-8?q?=20=E4=BC=98=E5=8C=96=20jepa.py=20=E4=B8=AD?= =?UTF-8?q?=E9=80=9A=E7=94=A8=20rollout=20=E7=83=AD=E8=B7=AF=E5=BE=84?= =?UTF-8?q?=EF=BC=9A=E6=89=B9=E9=87=8F=E9=A2=84=E7=BC=96=E7=A0=81=E5=8A=A8?= =?UTF-8?q?=20=20=20=E4=BD=9C=E3=80=81=E7=A7=BB=E9=99=A4=E5=BE=AA=E7=8E=AF?= =?UTF-8?q?=E5=86=85=20=20=20torch.cat=EF=BC=8C=E5=B9=B6=E4=B8=BA=20histor?= =?UTF-8?q?y=5Fsize=3D=3D1=20=E4=B8=8E=E7=8E=AF=E5=BD=A2=E7=BC=93=E5=86=B2?= =?UTF-8?q?=E5=8C=BA=E6=9B=B4=E6=96=B0=20=20=20=E6=B7=BB=E5=8A=A0=E6=9B=B4?= =?UTF-8?q?=E8=BD=BB=E9=87=8F=E5=AE=9E=E7=8E=B0=EF=BC=9B=20=E6=94=B6?= =?UTF-8?q?=E7=9B=8A=E4=B8=8D=E5=A4=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- jepa.py | 92 +++++++-- sth.md | 216 ++++---------------- tworoom_results.txt | 486 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 597 insertions(+), 197 deletions(-) diff --git a/jepa.py b/jepa.py index 6c0b947..4ecc2b7 100644 --- a/jepa.py +++ b/jepa.py @@ -133,35 +133,89 @@ class JEPA(nn.Module): """ with torch.profiler.record_function("lewm.rollout"): assert "pixels" in info, "pixels not in info_dict" + if history_size < 1: + raise ValueError("history_size must be >= 1") + H = info["pixels"].size(2) B, S, T = action_sequence.shape[:3] - act_0, act_future = torch.split(action_sequence, [H, T - H], dim=2) + if T < H: + raise ValueError( + f"action_sequence horizon ({T}) must be >= history length ({H})" + ) # Cache the encoded initial state across solver iterations. init_emb = self._get_cached_init_emb(info) HS = history_size - emb_hist = init_emb.unsqueeze(1).expand(B, S, -1, -1) - emb_hist = emb_hist[..., -HS:, :].reshape(B * S, min(HS, init_emb.size(1)), -1) + hist_len = min(HS, init_emb.size(1), H) + if hist_len < 1: + raise ValueError("rollout requires at least one history step") - act_hist = act_0[..., -HS:, :].reshape(B * S, min(HS, act_0.size(2)), -1) - act_emb_hist = self.action_encoder(act_hist) - act_future = act_future.reshape(B * S, act_future.size(2), -1) + init_hist = init_emb[:, -hist_len:] + init_hist = init_hist.unsqueeze(1).expand(-1, S, -1, -1) + init_hist = init_hist.reshape(B * S, hist_len, init_hist.size(-1)) - for t in range(act_future.size(1)): - pred_emb = self.predict(emb_hist[:, -HS:], act_emb_hist[:, -HS:])[:, -1:] - if HS > 1: - emb_hist = torch.cat([emb_hist[:, -HS + 1 :], pred_emb], dim=1) + flat_actions = action_sequence.reshape(B * S, T, -1) + action_emb = self.action_encoder(flat_actions) + act_hist = action_emb[:, H - hist_len : H] + act_future = action_emb[:, H:] + + if HS == 1: + emb_hist = init_hist[:, -1:] + act_emb_hist = act_hist[:, -1:] + + for t in range(act_future.size(1)): + emb_hist = self.predict(emb_hist, act_emb_hist)[:, -1:] + act_emb_hist = act_future[:, t : t + 1] + + pred_rollout = self.predict(emb_hist, act_emb_hist)[:, -1:] + else: + emb_hist = init_hist.new_empty((B * S, HS, init_hist.size(-1))) + act_emb_hist = action_emb.new_empty((B * S, HS, action_emb.size(-1))) + emb_hist[:, :hist_len].copy_(init_hist) + act_emb_hist[:, :hist_len].copy_(act_hist) + + history_order = torch.stack( + [ + (torch.arange(HS, device=action_emb.device) + offset) % HS + for offset in range(HS) + ] + ) + filled = hist_len + next_slot = hist_len % HS + + for t in range(act_future.size(1)): + if filled < HS: + emb_view = emb_hist[:, :filled] + act_view = act_emb_hist[:, :filled] + elif next_slot == 0: + emb_view = emb_hist + act_view = act_emb_hist + else: + order = history_order[next_slot] + emb_view = emb_hist.index_select(1, order) + act_view = act_emb_hist.index_select(1, order) + + pred_emb = self.predict(emb_view, act_view)[:, -1:] + next_act_emb = act_future[:, t : t + 1] + emb_hist[:, next_slot : next_slot + 1].copy_(pred_emb) + act_emb_hist[:, next_slot : next_slot + 1].copy_(next_act_emb) + + if filled < HS: + filled += 1 + next_slot = (next_slot + 1) % HS + + if filled < HS: + emb_view = emb_hist[:, :filled] + act_view = act_emb_hist[:, :filled] + elif next_slot == 0: + emb_view = emb_hist + act_view = act_emb_hist else: - emb_hist = pred_emb + order = history_order[next_slot] + emb_view = emb_hist.index_select(1, order) + act_view = act_emb_hist.index_select(1, order) - next_act = act_future[:, t : t + 1, :] - next_act_emb = self.action_encoder(next_act) - if HS > 1: - act_emb_hist = torch.cat([act_emb_hist[:, -HS + 1 :], next_act_emb], dim=1) - else: - act_emb_hist = next_act_emb - - pred_rollout = self.predict(emb_hist[:, -HS:], act_emb_hist[:, -HS:])[:, -1:] + pred_rollout = self.predict(emb_view, act_view)[:, -1:] info["predicted_emb"] = pred_rollout.reshape(B, S, *pred_rollout.shape[1:]) return info diff --git a/sth.md b/sth.md index 7ce0538..0221f1b 100644 --- a/sth.md +++ b/sth.md @@ -1,192 +1,52 @@ -1. 压 rollout 内环 - 这条最通用,而且基本不改算法语义,只是把实现做对。 +我建议优先做这 4 类,都是跨数据集成立的: - 在 jepa.py:129 这段里,当前问题是: - - - 循环里每步都 action_encoder(next_act),见 jepa.py:159 - - history 每步用 torch.cat 重建,见 jepa.py:155 和 jepa.py:162 - - 每步都走一次很短的 predict(),host 调度比例很高 - - 通用改法: + 1. 压 rollout 内环实现 + 见 jepa.py:127。现在每步都在做 action_encoder、切片、torch.cat、小规 + 模 predict 调用,这种碎片化实现对任何任务都亏。 + 通用改法: - 整条 action_sequence 一次性做 action_encoder - emb_hist / act_emb_hist 改成预分配 buffer - - 用 ring buffer 或 index rotate 更新历史 - - 循环里只做 copy_ / 索引覆盖,不做 cat + - 循环里只做索引覆盖或 copy_ + - 去掉循环内 torch.cat - 这个优化对任何数据集都成立,因为它优化的是“rolling inference 实现方 - 式”,不是任务参数。 + 2. 减少热路径里的搬运和同步 + profile 里 aten::copy_ 很重,这不是 TwoRoom 特有问题。重点看 + jepa.py:67 和 jepa.py:186。 + 通用目标: - 2. 用 torch.inference_mode() - 你现在在 eval.py:242 这里只用了 autocast,没有 inference_mode()。 + - 模型侧张量尽量全程留在 GPU + - 避免热路径反复 .to(device) / 隐式 layout 修复 + - 到必须和环境交互的边界再一次性转 CPU / numpy + - 确保进入 predictor 的张量是 contiguous 的,少触发隐式 copy - 建议推理主路径外层直接包: + 3. 把编译成本移出正式计时 + 现在 torch.compile 默认开在 predictor,见 eval.py:70。102s -> 45s 很 + 像首轮编译预热。 + 通用做法: - with torch.inference_mode(): - with inference_ctx: - ... + - 在正式 start_time 前做一次 dummy predict 或 dummy rollout + - 保留只编译 predictor/predict,不要编译整个 solver - 这是纯通用优化,所有数据集都受益。 + 4. 减少临时对象和 shape bookkeeping + 这是所有任务都会受益的。 + 重点看: - 3. 只编译 predictor / predict,不要编译整个 solver - 当前热点是大量小 predict() 调用,不是整条 eval graph。 + - jepa.py:100 到 jepa.py:106 + - jepa.py:143 到 jepa.py:148 + 方向是: + - 能循环外做的 reshape,不放循环里 + - 能原地更新,不新建张量 + - 少做 dict 字段增删和中间容器组装 - 通用建议: + 不建议优先做的通用性较差方案: - - 先只编译 self.predictor - - 或只编译 JEPA.predict() - - 模式优先试 reduce-overhead + - 调 TwoRoom 专属 cache 规则 + - 改数据集采样逻辑 + - 按小数据集特点缩短 horizon + - 直接改 CEM 超参当“优化” - 不要先编译整个 WorldModelPolicy 或 CEM solver;那通常图不稳定,泛化收益 - 反而差。 + 如果你要我直接开始改,我建议第一批只做两件事: - 4. 减少循环里的张量形状重排和临时对象 - 这也是实现层通用优化。 - - 可以继续查: - - - rearrange 是否能前移到循环外 - - 是否有重复的 slice/view 触发隐式拷贝 - - pred_proj(rearrange(...)) 这类 reshape 往返是否能合并 - - 这类优化对所有任务都有效,因为是在降 Python 和 tensor bookkeeping 成 - 本。 - - 5. 再考虑结构级优化,但放后面 - 比如 predictor 深度、MLP 宽度、heads 数量。这也通用,但已经开始碰模型容 - 量和精度,不该是第一刀。 - - 不建议优先做的 - 这些更偏任务/数据集相关,不算你要的“泛用优化”: - - - 先调 num_samples/topk/n_steps - - 先缩 horizon - - 先按 tworoom 特性做 shortcut - - 先针对某个 dataset 做 cache 规则 - - 一句话判断 - 你现在最像是“算法没错,但 rollout 实现过于碎片化”,所以第一优先级应该 - 是: - 一次性 action encode + 预分配历史 buffer + 去掉循环内 torch.cat + - inference_mode + compile predictor - - 如果你要,我下一步就直接改 jepa.py 做这套通用优化,不碰任何数据集特化逻 - 辑。 - - 除开 CEM solve 本体,剩下这些杂项可以这样优化。 - - 最高优先级 - - 1. 保证传给环境的是 numpy,不要让 Gym 代转 - 你日志已经说明 env step() 收到了 torch.Tensor。这会带来拷贝、同步、 - checker 额外开销。 - 做法: - - - 在 policy 输出动作、准备喂给 env 的那一层,显式转成 - action.detach().cpu().numpy() - - 最好一次性转好,别在 env 内部或 wrapper 内隐式转换 - - 收益: - - - 去掉 Gym warning - - 减少同步和类型检查开销 - - 通常是最直接的非模型提速点 - - 2. 关掉 Gym passive checker - 这些 warning 本身就说明 checker 在持续检查类型和空间匹配。 - 做法: - - - 尽量用禁 checker 的构造方式 - - 或在你自己的 env wrapper 里保证输入输出符合 Gym 预期,避免它每步检查 - - 收益: - - - 每步少一层 Python 校验 - - 对长 episode / 多 episode 累积明显 - - 中优先级 - - 3. 把预处理前移,避免每步重复做能缓存的东西 - 如果 goal、初始条件、某些 dataset 字段在 episode 里不变,就不要每次 - 都重新组织。 - 做法: - - - goal 相关 embedding 已经有缓存,继续扩展到更多静态字段 - - 固定的 callables 参数尽量预解析 - - 能在 episode 开头准备好的,不要放在 step 循环里 - - 收益: - - - 降低 Python dict 操作和小张量处理开销 - - 4. 避免频繁 CPU/GPU 来回切 - 如果模型在 GPU,但环境在 CPU,就要非常小心中间格式。 - 做法: - - - 模型侧尽量连续留在 GPU - - 到真正 env step 前再一次性转 CPU numpy - - 不要中间反复 .cpu() / .to(device) / np.array(...) - - 收益: - - - 减少隐式同步 - - 稳定延迟 - - 5. 缩减 Python 层对象操作 - dict 组装、字段拷贝、wrapper 嵌套太多时,端到端会慢。 - 做法: - - - 关键热路径里少做深拷贝 - - 少重复构造新的 info / obs 容器 - - 固定结构优先原地更新 - - 收益: - - - 对小步高频调用路径有效 - - 如果你要继续压评测时间 - - 6. 降低日志和 warning 输出 - 频繁 warning 会拖慢,而且污染 timing。 - 做法: - - - 修掉类型不匹配后 warning 自然消失 - - 非必要的 print 尤其是 step 内 print 要去掉 - - 7. 针对环境 step 做批量化检查 - 如果 num_envs=50,尽量确认 env wrapper 没有在内部退化成逐环境 Python - for-loop。 - 做法: - - - 查 world.evaluate_from_dataset() 到 env step() 之间是不是 batch 接口 - - 如果 batch env 里还有逐个 env 转换/检查,尽量前移或向量化 - - 收益: - - - 这类经常能解释“为什么 solver 时间之外还有很多时间” - - 8. 把 callables 的执行成本单独看 - 你这里有: - - - _set_state - - _set_goal_state - - - - 确认它们只在必要时执行 - - 能批量设置就别逐条 Python 调 - 2. 消掉 Gym warning - 3. 单独量 env.step 总时间 - 4. 检查是否有反复 CPU/GPU 转换 - 5. 再看 wrapper / callable / obs 组装 - - 一句话总结 - 剩下的杂项优化,核心不是“再多上几张卡”,而是: - - - 去掉隐式类型转换 - - 去掉多余检查 - - 去掉重复数据整理 - - 减少 CPU/GPU 往返 - - 减少 Python 高频小开销 - - 如果你要,我下一步可以直接帮你定位“动作是在哪一层以 torch.Tensor 传进 - env 的”,给你指出具体应该改哪个函数。 \ No newline at end of file + - 重写 jepa.py:127 这段 rollout,去掉循环内 action_encoder + cat + - 在 eval.py:306 前加 compile warmup \ No newline at end of file diff --git a/tworoom_results.txt b/tworoom_results.txt index 35f5cba..bc74d7c 100644 --- a/tworoom_results.txt +++ b/tworoom_results.txt @@ -1282,3 +1282,489 @@ evaluation_time: 44.974061727523804 seconds inference_precision: fp16 inference_compile_target: predictor inference_compile_mode: reduce-overhead + +==== CONFIG ==== +cache_dir: null +solver: + _target_: stable_worldmodel.solver.CEMSolver + model: ??? + batch_size: 1 + num_samples: 300 + var_scale: 1.0 + n_steps: 30 + topk: 30 + device: cuda + seed: ${seed} +world: + env_name: swm/TwoRoom-v1 + num_envs: ${eval.num_eval} + max_episode_steps: 100 + history_size: 1 + frame_skip: 1 +seed: 42 +policy: two-room/tworoom/lejepa +inference_precision: fp16 +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: tworoom + callables: + - method: _set_state + args: + state: + value: proprio + - method: _set_goal_state + args: + goal_state: + value: goal_proprio +output: + filename: tworoom_results.txt + +==== RESULTS ==== +metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False, + True, True, True, True, True, True, True, True, True, + True, True, True, True, True, True, True, True, True, + True, True, True, True, False, True, True, True, True, + True, True, False, False, True, True, True, True, True, + True, True, True, True, True]), 'seeds': None} +evaluation_time: 102.31317353248596 seconds +inference_precision: fp16 +inference_compile_target: predictor +inference_compile_mode: reduce-overhead + +==== CONFIG ==== +cache_dir: null +solver: + _target_: stable_worldmodel.solver.CEMSolver + model: ??? + batch_size: 1 + num_samples: 300 + var_scale: 1.0 + n_steps: 30 + topk: 30 + device: cuda + seed: ${seed} +world: + env_name: swm/TwoRoom-v1 + num_envs: ${eval.num_eval} + max_episode_steps: 100 + history_size: 1 + frame_skip: 1 +seed: 42 +policy: two-room/tworoom/lejepa +inference_precision: fp16 +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: tworoom + callables: + - method: _set_state + args: + state: + value: proprio + - method: _set_goal_state + args: + goal_state: + value: goal_proprio +output: + filename: tworoom_results.txt + +==== RESULTS ==== +metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False, + True, True, True, True, True, True, True, True, True, + True, True, True, True, True, True, True, True, True, + True, True, True, True, False, True, True, True, True, + True, True, False, False, True, True, True, True, True, + True, True, True, True, True]), 'seeds': None} +evaluation_time: 45.355348110198975 seconds +inference_precision: fp16 +inference_compile_target: predictor +inference_compile_mode: reduce-overhead + +==== CONFIG ==== +cache_dir: null +solver: + _target_: stable_worldmodel.solver.CEMSolver + model: ??? + batch_size: 1 + num_samples: 300 + var_scale: 1.0 + n_steps: 30 + topk: 30 + device: cuda + seed: ${seed} +world: + env_name: swm/TwoRoom-v1 + num_envs: ${eval.num_eval} + max_episode_steps: 100 + history_size: 1 + frame_skip: 1 +seed: 42 +policy: two-room/tworoom/lejepa +inference_precision: fp16 +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: tworoom + callables: + - method: _set_state + args: + state: + value: proprio + - method: _set_goal_state + args: + goal_state: + value: goal_proprio +output: + filename: tworoom_results.txt +profile: + enabled: true + export_tensorboard: false + export_chrome_trace: false + +==== RESULTS ==== +metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False, + True, True, True, True, True, True, True, True, True, + True, True, True, True, True, True, True, True, True, + True, True, True, True, False, True, True, True, True, + True, True, False, False, True, True, True, True, True, + True, True, True, True, True]), 'seeds': None} +evaluation_time: 110.91939687728882 seconds +inference_precision: fp16 +inference_compile_target: predictor +inference_compile_mode: reduce-overhead +profile_dir: /mnt/ASC1637/lewm_baseline/le-wm/torch_profile +profile_summary: /mnt/ASC1637/lewm_baseline/le-wm/torch_profile/key_averages.txt + +==== CONFIG ==== +cache_dir: null +solver: + _target_: stable_worldmodel.solver.CEMSolver + model: ??? + batch_size: 1 + num_samples: 300 + var_scale: 1.0 + n_steps: 30 + topk: 30 + device: cuda + seed: ${seed} +world: + env_name: swm/TwoRoom-v1 + num_envs: ${eval.num_eval} + max_episode_steps: 100 + history_size: 1 + frame_skip: 1 +seed: 42 +policy: two-room/tworoom/lejepa +inference_precision: fp16 +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: tworoom + callables: + - method: _set_state + args: + state: + value: proprio + - method: _set_goal_state + args: + goal_state: + value: goal_proprio +output: + filename: tworoom_results.txt + +==== RESULTS ==== +metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False, + True, True, True, True, True, True, True, True, True, + True, True, True, False, True, True, True, True, True, + True, True, True, True, False, True, True, True, True, + True, True, False, True, True, True, True, True, True, + True, True, True, True, True]), 'seeds': None} +evaluation_time: 54.21496343612671 seconds +inference_precision: fp16 +inference_compile_target: predictor +inference_compile_mode: reduce-overhead + +==== CONFIG ==== +cache_dir: null +solver: + _target_: stable_worldmodel.solver.CEMSolver + model: ??? + batch_size: 1 + num_samples: 300 + var_scale: 1.0 + n_steps: 30 + topk: 30 + device: cuda + seed: ${seed} +world: + env_name: swm/TwoRoom-v1 + num_envs: ${eval.num_eval} + max_episode_steps: 100 + history_size: 1 + frame_skip: 1 +seed: 42 +policy: two-room/tworoom/lejepa +inference_precision: fp16 +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: tworoom + callables: + - method: _set_state + args: + state: + value: proprio + - method: _set_goal_state + args: + goal_state: + value: goal_proprio +output: + filename: tworoom_results.txt + +==== RESULTS ==== +metrics: {'success_rate': 90.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False, + True, True, True, True, True, True, True, True, True, + True, True, True, False, True, True, True, True, True, + True, True, True, True, False, True, True, True, True, + True, True, True, True, True, True, True, True, True, + True, True, True, True, True]), 'seeds': None} +evaluation_time: 43.69562244415283 seconds +inference_precision: fp16 +inference_compile_target: predictor +inference_compile_mode: reduce-overhead + +==== CONFIG ==== +cache_dir: null +solver: + _target_: stable_worldmodel.solver.CEMSolver + model: ??? + batch_size: 1 + num_samples: 300 + var_scale: 1.0 + n_steps: 30 + topk: 30 + device: cuda + seed: ${seed} +world: + env_name: swm/TwoRoom-v1 + num_envs: ${eval.num_eval} + max_episode_steps: 100 + history_size: 1 + frame_skip: 1 +seed: 42 +policy: two-room/tworoom/lejepa +inference_precision: fp16 +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: tworoom + callables: + - method: _set_state + args: + state: + value: proprio + - method: _set_goal_state + args: + goal_state: + value: goal_proprio +output: + filename: tworoom_results.txt + +==== RESULTS ==== +metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False, + True, True, True, True, True, True, True, True, True, + True, True, True, False, True, True, True, True, True, + True, True, True, True, False, True, True, True, True, + True, True, False, True, True, True, True, True, True, + True, True, True, True, True]), 'seeds': None} +evaluation_time: 42.99847435951233 seconds +inference_precision: fp16 +inference_compile_target: predictor +inference_compile_mode: reduce-overhead + +==== CONFIG ==== +cache_dir: null +solver: + _target_: stable_worldmodel.solver.CEMSolver + model: ??? + batch_size: 1 + num_samples: 300 + var_scale: 1.0 + n_steps: 30 + topk: 30 + device: cuda + seed: ${seed} +world: + env_name: swm/TwoRoom-v1 + num_envs: ${eval.num_eval} + max_episode_steps: 100 + history_size: 1 + frame_skip: 1 +seed: 42 +policy: two-room/tworoom/lejepa +inference_precision: fp16 +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: tworoom + callables: + - method: _set_state + args: + state: + value: proprio + - method: _set_goal_state + args: + goal_state: + value: goal_proprio +output: + filename: tworoom_results.txt + +==== RESULTS ==== +metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False, + True, True, True, True, True, True, True, True, True, + True, True, True, False, True, True, True, True, True, + True, True, True, True, False, True, True, True, True, + True, True, False, True, True, True, True, True, True, + True, True, True, True, True]), 'seeds': None} +evaluation_time: 43.14276576042175 seconds +inference_precision: fp16 +inference_compile_target: predictor +inference_compile_mode: reduce-overhead + +==== CONFIG ==== +cache_dir: null +solver: + _target_: stable_worldmodel.solver.CEMSolver + model: ??? + batch_size: 1 + num_samples: 300 + var_scale: 1.0 + n_steps: 30 + topk: 30 + device: cuda + seed: ${seed} +world: + env_name: swm/TwoRoom-v1 + num_envs: ${eval.num_eval} + max_episode_steps: 100 + history_size: 1 + frame_skip: 1 +seed: 42 +policy: two-room/tworoom/lejepa +inference_precision: fp16 +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: tworoom + callables: + - method: _set_state + args: + state: + value: proprio + - method: _set_goal_state + args: + goal_state: + value: goal_proprio +output: + filename: tworoom_results.txt + +==== RESULTS ==== +metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False, + True, True, True, True, True, True, True, True, True, + True, True, True, False, True, True, True, True, True, + True, True, True, True, False, True, True, True, True, + True, True, False, True, True, True, True, True, True, + True, True, True, True, True]), 'seeds': None} +evaluation_time: 43.71034002304077 seconds +inference_precision: fp16 +inference_compile_target: predictor +inference_compile_mode: reduce-overhead