加入一个提前停止的机制 还有减少环境步中间步骤传递至cpu

This commit is contained in:
qihuanye
2026-05-18 00:48:59 +08:00
parent 113e591899
commit 28f2fba0e8
5 changed files with 138 additions and 16 deletions

View File

@@ -394,6 +394,16 @@ class WorldModelPolicy(BasePolicy):
'Solver must implement the Solver protocol'
)
def _normalize_active_mask(self, active_mask: Any) -> np.ndarray | None:
if active_mask is None:
return None
active_mask = np.asarray(active_mask, dtype=bool)
if active_mask.ndim != 1 or active_mask.shape[0] != self.env.num_envs:
raise ValueError(
f"active_mask must have shape ({self.env.num_envs},), got {active_mask.shape}"
)
return active_mask
def get_action(self, info_dict: dict, **kwargs: Any) -> np.ndarray:
"""Get action via planning with the world model.
@@ -405,17 +415,34 @@ class WorldModelPolicy(BasePolicy):
The selected action(s) as a numpy array.
"""
assert hasattr(self, 'env'), 'Environment not set for the policy'
assert 'pixels' in info_dict, "'pixels' must be provided in info_dict"
assert 'goal' in info_dict, "'goal' must be provided in info_dict"
info_dict = self._prepare_info(info_dict)
info_dict = self._move_info_to_device(info_dict, self.solver.device)
active_mask = self._normalize_active_mask(kwargs.get("active_mask"))
# need to replan if action buffer is empty
if len(self._action_buffer) == 0:
outputs = self.solver(info_dict, init_action=self._next_init)
assert 'pixels' in info_dict, "'pixels' must be provided in info_dict"
assert 'goal' in info_dict, "'goal' must be provided in info_dict"
info_dict = self._prepare_info(info_dict)
info_dict = self._move_info_to_device(info_dict, self.solver.device)
outputs = self.solver(
info_dict,
init_action=self._next_init,
active_mask=active_mask,
)
actions = outputs['actions'] # (num_envs, horizon, action_dim)
if active_mask is not None and actions.shape[0] != self.env.num_envs:
full_actions = torch.zeros(
self.env.num_envs,
actions.shape[1],
actions.shape[2],
dtype=actions.dtype,
device=actions.device,
)
full_actions[torch.as_tensor(active_mask, device=actions.device)] = actions
actions = full_actions
keep_horizon = self.cfg.receding_horizon
plan = actions[:, :keep_horizon]
rest = actions[:, keep_horizon:]
@@ -430,6 +457,16 @@ class WorldModelPolicy(BasePolicy):
action = self._action_buffer.popleft()
action = action.reshape(*self.env.action_space.shape)
if active_mask is not None:
if torch.is_tensor(action):
inactive_mask = torch.as_tensor(
~active_mask, device=action.device, dtype=torch.bool
)
action = action.clone()
action[inactive_mask] = 0
else:
action = np.array(action, copy=True)
action[~active_mask] = 0
if torch.is_tensor(action):
action = action.detach().cpu().numpy()
else: