加入提前停止机制，并减少环境步中间结果向 CPU 的传输
This commit is contained in:
@@ -394,6 +394,16 @@ class WorldModelPolicy(BasePolicy):
|
||||
'Solver must implement the Solver protocol'
|
||||
)
|
||||
|
||||
def _normalize_active_mask(self, active_mask: Any) -> np.ndarray | None:
|
||||
if active_mask is None:
|
||||
return None
|
||||
active_mask = np.asarray(active_mask, dtype=bool)
|
||||
if active_mask.ndim != 1 or active_mask.shape[0] != self.env.num_envs:
|
||||
raise ValueError(
|
||||
f"active_mask must have shape ({self.env.num_envs},), got {active_mask.shape}"
|
||||
)
|
||||
return active_mask
|
||||
|
||||
def get_action(self, info_dict: dict, **kwargs: Any) -> np.ndarray:
|
||||
"""Get action via planning with the world model.
|
||||
|
||||
@@ -405,17 +415,34 @@ class WorldModelPolicy(BasePolicy):
|
||||
The selected action(s) as a numpy array.
|
||||
"""
|
||||
assert hasattr(self, 'env'), 'Environment not set for the policy'
|
||||
assert 'pixels' in info_dict, "'pixels' must be provided in info_dict"
|
||||
assert 'goal' in info_dict, "'goal' must be provided in info_dict"
|
||||
|
||||
info_dict = self._prepare_info(info_dict)
|
||||
info_dict = self._move_info_to_device(info_dict, self.solver.device)
|
||||
active_mask = self._normalize_active_mask(kwargs.get("active_mask"))
|
||||
|
||||
# need to replan if action buffer is empty
|
||||
if len(self._action_buffer) == 0:
|
||||
outputs = self.solver(info_dict, init_action=self._next_init)
|
||||
assert 'pixels' in info_dict, "'pixels' must be provided in info_dict"
|
||||
assert 'goal' in info_dict, "'goal' must be provided in info_dict"
|
||||
|
||||
info_dict = self._prepare_info(info_dict)
|
||||
info_dict = self._move_info_to_device(info_dict, self.solver.device)
|
||||
|
||||
outputs = self.solver(
|
||||
info_dict,
|
||||
init_action=self._next_init,
|
||||
active_mask=active_mask,
|
||||
)
|
||||
|
||||
actions = outputs['actions'] # (num_envs, horizon, action_dim)
|
||||
if active_mask is not None and actions.shape[0] != self.env.num_envs:
|
||||
full_actions = torch.zeros(
|
||||
self.env.num_envs,
|
||||
actions.shape[1],
|
||||
actions.shape[2],
|
||||
dtype=actions.dtype,
|
||||
device=actions.device,
|
||||
)
|
||||
full_actions[torch.as_tensor(active_mask, device=actions.device)] = actions
|
||||
actions = full_actions
|
||||
|
||||
keep_horizon = self.cfg.receding_horizon
|
||||
plan = actions[:, :keep_horizon]
|
||||
rest = actions[:, keep_horizon:]
|
||||
@@ -430,6 +457,16 @@ class WorldModelPolicy(BasePolicy):
|
||||
|
||||
action = self._action_buffer.popleft()
|
||||
action = action.reshape(*self.env.action_space.shape)
|
||||
if active_mask is not None:
|
||||
if torch.is_tensor(action):
|
||||
inactive_mask = torch.as_tensor(
|
||||
~active_mask, device=action.device, dtype=torch.bool
|
||||
)
|
||||
action = action.clone()
|
||||
action[inactive_mask] = 0
|
||||
else:
|
||||
action = np.array(action, copy=True)
|
||||
action[~active_mask] = 0
|
||||
if torch.is_tensor(action):
|
||||
action = action.detach().cpu().numpy()
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user