继续做了通用性能优化,重点从 jepa.py 热路径转到实际的 stable_worldmodel
solver/policy 边界:去掉 CEM 每轮 cpu().tolist() 和结果过早回 CPU,把 plan/warm-start 保持在 GPU,只在 env.step 前最后一步转成 numpy,同时补 了输入张量的 contiguous 处理;
This commit is contained in:
@@ -122,8 +122,12 @@ class BasePolicy:
|
||||
) -> dict[str, Any]:
|
||||
target = torch.device(device)
|
||||
for k, v in info_dict.items():
|
||||
if torch.is_tensor(v) and v.device != target:
|
||||
info_dict[k] = v.to(target, non_blocking=True)
|
||||
if torch.is_tensor(v):
|
||||
if v.device != target:
|
||||
v = v.to(target, non_blocking=True)
|
||||
if not v.is_contiguous():
|
||||
v = v.contiguous()
|
||||
info_dict[k] = v
|
||||
return info_dict
|
||||
|
||||
def _prepare_info(self, info_dict: dict) -> dict[str, torch.Tensor]:
|
||||
@@ -415,18 +419,21 @@ class WorldModelPolicy(BasePolicy):
|
||||
keep_horizon = self.cfg.receding_horizon
|
||||
plan = actions[:, :keep_horizon]
|
||||
rest = actions[:, keep_horizon:]
|
||||
self._next_init = rest if self.cfg.warm_start else None
|
||||
self._next_init = rest.contiguous() if self.cfg.warm_start else None
|
||||
|
||||
# frameskip back to timestep
|
||||
plan = plan.reshape(
|
||||
self.env.num_envs, self.flatten_receding_horizon, -1
|
||||
)
|
||||
).contiguous()
|
||||
|
||||
self._action_buffer.extend(plan.transpose(0, 1))
|
||||
self._action_buffer.extend(plan.transpose(0, 1).unbind(0))
|
||||
|
||||
action = self._action_buffer.popleft()
|
||||
action = action.reshape(*self.env.action_space.shape)
|
||||
action = action.numpy()
|
||||
if torch.is_tensor(action):
|
||||
action = action.detach().cpu().numpy()
|
||||
else:
|
||||
action = np.asarray(action)
|
||||
|
||||
# post-process action
|
||||
if 'action' in self.process:
|
||||
|
||||
Reference in New Issue
Block a user