继续做了通用性能优化,重点从 jepa.py 热路径转到实际的 stable_worldmodel

solver/policy 边界:去掉 CEM 每轮 cpu().tolist() 和结果过早回 CPU,把
  plan/warm-start 保持在 GPU,只在 env.step 前最后一步转成 numpy,同时补
  了输入张量的 contiguous 处理;
This commit is contained in:
qihuanye
2026-04-09 12:33:50 +00:00
parent 995cd8cfec
commit 25e4ddb628
4 changed files with 432 additions and 29 deletions

View File

@@ -122,8 +122,12 @@ class BasePolicy:
) -> dict[str, Any]:
target = torch.device(device)
for k, v in info_dict.items():
if torch.is_tensor(v) and v.device != target:
info_dict[k] = v.to(target, non_blocking=True)
if torch.is_tensor(v):
if v.device != target:
v = v.to(target, non_blocking=True)
if not v.is_contiguous():
v = v.contiguous()
info_dict[k] = v
return info_dict
def _prepare_info(self, info_dict: dict) -> dict[str, torch.Tensor]:
@@ -415,18 +419,21 @@ class WorldModelPolicy(BasePolicy):
keep_horizon = self.cfg.receding_horizon
plan = actions[:, :keep_horizon]
rest = actions[:, keep_horizon:]
self._next_init = rest if self.cfg.warm_start else None
self._next_init = rest.contiguous() if self.cfg.warm_start else None
# frameskip back to timestep
plan = plan.reshape(
self.env.num_envs, self.flatten_receding_horizon, -1
)
).contiguous()
self._action_buffer.extend(plan.transpose(0, 1))
self._action_buffer.extend(plan.transpose(0, 1).unbind(0))
action = self._action_buffer.popleft()
action = action.reshape(*self.env.action_space.shape)
action = action.numpy()
if torch.is_tensor(action):
action = action.detach().cpu().numpy()
else:
action = np.asarray(action)
# post-process action
if 'action' in self.process: