加入一个提前停止的机制 还有减少环境步中间步骤传递至cpu
This commit is contained in:
@@ -394,6 +394,16 @@ class WorldModelPolicy(BasePolicy):
|
||||
'Solver must implement the Solver protocol'
|
||||
)
|
||||
|
||||
def _normalize_active_mask(self, active_mask: Any) -> np.ndarray | None:
|
||||
if active_mask is None:
|
||||
return None
|
||||
active_mask = np.asarray(active_mask, dtype=bool)
|
||||
if active_mask.ndim != 1 or active_mask.shape[0] != self.env.num_envs:
|
||||
raise ValueError(
|
||||
f"active_mask must have shape ({self.env.num_envs},), got {active_mask.shape}"
|
||||
)
|
||||
return active_mask
|
||||
|
||||
def get_action(self, info_dict: dict, **kwargs: Any) -> np.ndarray:
|
||||
"""Get action via planning with the world model.
|
||||
|
||||
@@ -405,17 +415,34 @@ class WorldModelPolicy(BasePolicy):
|
||||
The selected action(s) as a numpy array.
|
||||
"""
|
||||
assert hasattr(self, 'env'), 'Environment not set for the policy'
|
||||
assert 'pixels' in info_dict, "'pixels' must be provided in info_dict"
|
||||
assert 'goal' in info_dict, "'goal' must be provided in info_dict"
|
||||
|
||||
info_dict = self._prepare_info(info_dict)
|
||||
info_dict = self._move_info_to_device(info_dict, self.solver.device)
|
||||
active_mask = self._normalize_active_mask(kwargs.get("active_mask"))
|
||||
|
||||
# need to replan if action buffer is empty
|
||||
if len(self._action_buffer) == 0:
|
||||
outputs = self.solver(info_dict, init_action=self._next_init)
|
||||
assert 'pixels' in info_dict, "'pixels' must be provided in info_dict"
|
||||
assert 'goal' in info_dict, "'goal' must be provided in info_dict"
|
||||
|
||||
info_dict = self._prepare_info(info_dict)
|
||||
info_dict = self._move_info_to_device(info_dict, self.solver.device)
|
||||
|
||||
outputs = self.solver(
|
||||
info_dict,
|
||||
init_action=self._next_init,
|
||||
active_mask=active_mask,
|
||||
)
|
||||
|
||||
actions = outputs['actions'] # (num_envs, horizon, action_dim)
|
||||
if active_mask is not None and actions.shape[0] != self.env.num_envs:
|
||||
full_actions = torch.zeros(
|
||||
self.env.num_envs,
|
||||
actions.shape[1],
|
||||
actions.shape[2],
|
||||
dtype=actions.dtype,
|
||||
device=actions.device,
|
||||
)
|
||||
full_actions[torch.as_tensor(active_mask, device=actions.device)] = actions
|
||||
actions = full_actions
|
||||
|
||||
keep_horizon = self.cfg.receding_horizon
|
||||
plan = actions[:, :keep_horizon]
|
||||
rest = actions[:, keep_horizon:]
|
||||
@@ -430,6 +457,16 @@ class WorldModelPolicy(BasePolicy):
|
||||
|
||||
action = self._action_buffer.popleft()
|
||||
action = action.reshape(*self.env.action_space.shape)
|
||||
if active_mask is not None:
|
||||
if torch.is_tensor(action):
|
||||
inactive_mask = torch.as_tensor(
|
||||
~active_mask, device=action.device, dtype=torch.bool
|
||||
)
|
||||
action = action.clone()
|
||||
action[inactive_mask] = 0
|
||||
else:
|
||||
action = np.array(action, copy=True)
|
||||
action[~active_mask] = 0
|
||||
if torch.is_tensor(action):
|
||||
action = action.detach().cpu().numpy()
|
||||
else:
|
||||
|
||||
@@ -76,6 +76,24 @@ class CEMSolver:
|
||||
"""Make solver callable, forwarding to solve()."""
|
||||
return self.solve(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_active_mask(
|
||||
active_mask: torch.Tensor | np.ndarray | None,
|
||||
n_envs: int,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor | None:
|
||||
if active_mask is None:
|
||||
return None
|
||||
if not torch.is_tensor(active_mask):
|
||||
active_mask = torch.as_tensor(active_mask, dtype=torch.bool, device=device)
|
||||
else:
|
||||
active_mask = active_mask.to(device=device, dtype=torch.bool)
|
||||
if active_mask.ndim != 1 or active_mask.shape[0] != n_envs:
|
||||
raise ValueError(
|
||||
f"active_mask must have shape ({n_envs},), got {tuple(active_mask.shape)}"
|
||||
)
|
||||
return active_mask
|
||||
|
||||
def init_action_distrib(
|
||||
self, actions: torch.Tensor | None = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -103,7 +121,10 @@ class CEMSolver:
|
||||
|
||||
@torch.inference_mode()
|
||||
def solve(
|
||||
self, info_dict: dict, init_action: torch.Tensor | None = None
|
||||
self,
|
||||
info_dict: dict,
|
||||
init_action: torch.Tensor | None = None,
|
||||
active_mask: torch.Tensor | np.ndarray | None = None,
|
||||
) -> dict:
|
||||
"""Solve the planning problem using Cross Entropy Method."""
|
||||
start_time = time.time()
|
||||
@@ -119,6 +140,17 @@ class CEMSolver:
|
||||
mean = mean.to(self.device, non_blocking=True)
|
||||
if var.device != torch.device(self.device):
|
||||
var = var.to(self.device, non_blocking=True)
|
||||
active_mask = self._normalize_active_mask(
|
||||
active_mask, self.n_envs, torch.device(self.device)
|
||||
)
|
||||
|
||||
if active_mask is not None and not torch.any(active_mask):
|
||||
return {
|
||||
"costs": [],
|
||||
"actions": mean.detach(),
|
||||
"mean": [mean.detach()],
|
||||
"var": [var.detach()],
|
||||
}
|
||||
|
||||
total_envs = self.n_envs
|
||||
|
||||
@@ -148,6 +180,25 @@ class CEMSolver:
|
||||
v_batch = np.repeat(v_batch[:, None, ...], self.num_samples, axis=1)
|
||||
expanded_infos[k] = v_batch
|
||||
|
||||
if active_mask is not None:
|
||||
batch_mask = active_mask[start_idx:end_idx]
|
||||
if not torch.any(batch_mask):
|
||||
outputs["costs"].append(
|
||||
torch.full((current_bs,), float("nan"), device=self.device)
|
||||
)
|
||||
continue
|
||||
active_local = torch.nonzero(batch_mask, as_tuple=False).squeeze(1)
|
||||
active_local_np = active_local.detach().cpu().numpy()
|
||||
batch_mean = batch_mean[active_local]
|
||||
batch_var = batch_var[active_local]
|
||||
expanded_infos = {
|
||||
k: (v[active_local] if torch.is_tensor(v) else v[active_local_np])
|
||||
for k, v in expanded_infos.items()
|
||||
}
|
||||
current_bs = int(active_local.numel())
|
||||
else:
|
||||
active_local = None
|
||||
|
||||
# Optimization Loop
|
||||
final_batch_cost = None
|
||||
batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1)
|
||||
@@ -198,13 +249,26 @@ class CEMSolver:
|
||||
final_batch_cost = topk_vals.mean(dim=1).detach()
|
||||
|
||||
# Write results back to global storage
|
||||
mean[start_idx:end_idx] = batch_mean
|
||||
var[start_idx:end_idx] = batch_var
|
||||
if active_mask is not None:
|
||||
global_indices = start_idx + active_local
|
||||
mean[global_indices] = batch_mean
|
||||
var[global_indices] = batch_var
|
||||
batch_costs = torch.full(
|
||||
(end_idx - start_idx,), float("nan"), device=self.device
|
||||
)
|
||||
batch_costs[active_local] = final_batch_cost
|
||||
else:
|
||||
mean[start_idx:end_idx] = batch_mean
|
||||
var[start_idx:end_idx] = batch_var
|
||||
batch_costs = final_batch_cost
|
||||
|
||||
# Store history/metadata
|
||||
outputs["costs"].append(final_batch_cost)
|
||||
outputs["costs"].append(batch_costs)
|
||||
|
||||
outputs["costs"] = torch.cat(outputs["costs"]).cpu().tolist()
|
||||
if outputs["costs"]:
|
||||
outputs["costs"] = torch.cat(outputs["costs"]).cpu().tolist()
|
||||
else:
|
||||
outputs["costs"] = []
|
||||
outputs["actions"] = mean.detach()
|
||||
outputs["mean"] = [mean.detach()]
|
||||
outputs["var"] = [var.detach()]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
@@ -61,7 +62,12 @@ class Solver(Protocol):
|
||||
"""Planning horizon length in timesteps."""
|
||||
...
|
||||
|
||||
def solve(self, info_dict: dict, init_action: torch.Tensor | None = None) -> dict:
|
||||
def solve(
|
||||
self,
|
||||
info_dict: dict,
|
||||
init_action: torch.Tensor | None = None,
|
||||
active_mask: torch.Tensor | np.ndarray | None = None,
|
||||
) -> dict:
|
||||
"""Solve the planning optimization problem to find optimal actions.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -984,17 +984,32 @@ class World:
|
||||
)
|
||||
|
||||
# run normal evaluation for eval_budget and record video
|
||||
active_mask = np.ones(self.num_envs, dtype=bool)
|
||||
last_eval_step = 0
|
||||
for i in range(eval_budget):
|
||||
video_frames[:, i] = self.infos['pixels'][:, -1]
|
||||
last_eval_step = i
|
||||
self.infos.update(goal_step)
|
||||
self.step()
|
||||
actions = self.policy.get_action(self.infos, active_mask=active_mask)
|
||||
(
|
||||
self.states,
|
||||
self.rewards,
|
||||
self.terminateds,
|
||||
self.truncateds,
|
||||
self.infos,
|
||||
) = self.envs.step(actions)
|
||||
results['episode_successes'] = np.logical_or(
|
||||
results['episode_successes'], self.terminateds
|
||||
)
|
||||
active_mask = np.logical_not(results['episode_successes'])
|
||||
if not np.any(active_mask):
|
||||
break
|
||||
# for auto-reset
|
||||
self.envs.unwrapped._autoreset_envs = np.zeros((self.num_envs,))
|
||||
|
||||
video_frames[:, -1] = self.infos['pixels'][:, -1]
|
||||
video_frames[:, last_eval_step] = self.infos['pixels'][:, -1]
|
||||
if last_eval_step + 1 < eval_budget:
|
||||
video_frames[:, last_eval_step + 1 :] = video_frames[:, last_eval_step : last_eval_step + 1]
|
||||
|
||||
n_episodes = len(episodes_idx)
|
||||
|
||||
|
||||
@@ -2,9 +2,9 @@ _target_: stable_worldmodel.solver.CEMSolver
|
||||
model: ???
|
||||
batch_size: 16
|
||||
# Original defaults: num_samples=300, n_steps=30, topk=30, batch_size=8.
|
||||
num_samples: 64
|
||||
num_samples: 300
|
||||
var_scale: 1.0
|
||||
n_steps: 10
|
||||
n_steps: 30
|
||||
topk: 8
|
||||
device: "cuda"
|
||||
seed: ${seed}
|
||||
|
||||
Reference in New Issue
Block a user