加入一个提前停止的机制 还有减少环境步中间步骤传递至cpu

This commit is contained in:
qihuanye
2026-05-18 00:48:59 +08:00
parent 113e591899
commit 28f2fba0e8
5 changed files with 138 additions and 16 deletions

View File

@@ -394,6 +394,16 @@ class WorldModelPolicy(BasePolicy):
'Solver must implement the Solver protocol'
)
def _normalize_active_mask(self, active_mask: Any) -> np.ndarray | None:
if active_mask is None:
return None
active_mask = np.asarray(active_mask, dtype=bool)
if active_mask.ndim != 1 or active_mask.shape[0] != self.env.num_envs:
raise ValueError(
f"active_mask must have shape ({self.env.num_envs},), got {active_mask.shape}"
)
return active_mask
def get_action(self, info_dict: dict, **kwargs: Any) -> np.ndarray:
"""Get action via planning with the world model.
@@ -405,17 +415,34 @@ class WorldModelPolicy(BasePolicy):
The selected action(s) as a numpy array.
"""
assert hasattr(self, 'env'), 'Environment not set for the policy'
assert 'pixels' in info_dict, "'pixels' must be provided in info_dict"
assert 'goal' in info_dict, "'goal' must be provided in info_dict"
info_dict = self._prepare_info(info_dict)
info_dict = self._move_info_to_device(info_dict, self.solver.device)
active_mask = self._normalize_active_mask(kwargs.get("active_mask"))
# need to replan if action buffer is empty
if len(self._action_buffer) == 0:
outputs = self.solver(info_dict, init_action=self._next_init)
assert 'pixels' in info_dict, "'pixels' must be provided in info_dict"
assert 'goal' in info_dict, "'goal' must be provided in info_dict"
info_dict = self._prepare_info(info_dict)
info_dict = self._move_info_to_device(info_dict, self.solver.device)
outputs = self.solver(
info_dict,
init_action=self._next_init,
active_mask=active_mask,
)
actions = outputs['actions'] # (num_envs, horizon, action_dim)
if active_mask is not None and actions.shape[0] != self.env.num_envs:
full_actions = torch.zeros(
self.env.num_envs,
actions.shape[1],
actions.shape[2],
dtype=actions.dtype,
device=actions.device,
)
full_actions[torch.as_tensor(active_mask, device=actions.device)] = actions
actions = full_actions
keep_horizon = self.cfg.receding_horizon
plan = actions[:, :keep_horizon]
rest = actions[:, keep_horizon:]
@@ -430,6 +457,16 @@ class WorldModelPolicy(BasePolicy):
action = self._action_buffer.popleft()
action = action.reshape(*self.env.action_space.shape)
if active_mask is not None:
if torch.is_tensor(action):
inactive_mask = torch.as_tensor(
~active_mask, device=action.device, dtype=torch.bool
)
action = action.clone()
action[inactive_mask] = 0
else:
action = np.array(action, copy=True)
action[~active_mask] = 0
if torch.is_tensor(action):
action = action.detach().cpu().numpy()
else:

View File

@@ -76,6 +76,24 @@ class CEMSolver:
"""Make solver callable, forwarding to solve()."""
return self.solve(*args, **kwargs)
@staticmethod
def _normalize_active_mask(
active_mask: torch.Tensor | np.ndarray | None,
n_envs: int,
device: torch.device,
) -> torch.Tensor | None:
if active_mask is None:
return None
if not torch.is_tensor(active_mask):
active_mask = torch.as_tensor(active_mask, dtype=torch.bool, device=device)
else:
active_mask = active_mask.to(device=device, dtype=torch.bool)
if active_mask.ndim != 1 or active_mask.shape[0] != n_envs:
raise ValueError(
f"active_mask must have shape ({n_envs},), got {tuple(active_mask.shape)}"
)
return active_mask
def init_action_distrib(
self, actions: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -103,7 +121,10 @@ class CEMSolver:
@torch.inference_mode()
def solve(
self, info_dict: dict, init_action: torch.Tensor | None = None
self,
info_dict: dict,
init_action: torch.Tensor | None = None,
active_mask: torch.Tensor | np.ndarray | None = None,
) -> dict:
"""Solve the planning problem using Cross Entropy Method."""
start_time = time.time()
@@ -119,6 +140,17 @@ class CEMSolver:
mean = mean.to(self.device, non_blocking=True)
if var.device != torch.device(self.device):
var = var.to(self.device, non_blocking=True)
active_mask = self._normalize_active_mask(
active_mask, self.n_envs, torch.device(self.device)
)
if active_mask is not None and not torch.any(active_mask):
return {
"costs": [],
"actions": mean.detach(),
"mean": [mean.detach()],
"var": [var.detach()],
}
total_envs = self.n_envs
@@ -148,6 +180,25 @@ class CEMSolver:
v_batch = np.repeat(v_batch[:, None, ...], self.num_samples, axis=1)
expanded_infos[k] = v_batch
if active_mask is not None:
batch_mask = active_mask[start_idx:end_idx]
if not torch.any(batch_mask):
outputs["costs"].append(
torch.full((current_bs,), float("nan"), device=self.device)
)
continue
active_local = torch.nonzero(batch_mask, as_tuple=False).squeeze(1)
active_local_np = active_local.detach().cpu().numpy()
batch_mean = batch_mean[active_local]
batch_var = batch_var[active_local]
expanded_infos = {
k: (v[active_local] if torch.is_tensor(v) else v[active_local_np])
for k, v in expanded_infos.items()
}
current_bs = int(active_local.numel())
else:
active_local = None
# Optimization Loop
final_batch_cost = None
batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1)
@@ -198,13 +249,26 @@ class CEMSolver:
final_batch_cost = topk_vals.mean(dim=1).detach()
# Write results back to global storage
mean[start_idx:end_idx] = batch_mean
var[start_idx:end_idx] = batch_var
if active_mask is not None:
global_indices = start_idx + active_local
mean[global_indices] = batch_mean
var[global_indices] = batch_var
batch_costs = torch.full(
(end_idx - start_idx,), float("nan"), device=self.device
)
batch_costs[active_local] = final_batch_cost
else:
mean[start_idx:end_idx] = batch_mean
var[start_idx:end_idx] = batch_var
batch_costs = final_batch_cost
# Store history/metadata
outputs["costs"].append(final_batch_cost)
outputs["costs"].append(batch_costs)
outputs["costs"] = torch.cat(outputs["costs"]).cpu().tolist()
if outputs["costs"]:
outputs["costs"] = torch.cat(outputs["costs"]).cpu().tolist()
else:
outputs["costs"] = []
outputs["actions"] = mean.detach()
outputs["mean"] = [mean.detach()]
outputs["var"] = [var.detach()]

View File

@@ -1,6 +1,7 @@
from typing import Any, Protocol, runtime_checkable
import gymnasium as gym
import numpy as np
import torch
@@ -61,7 +62,12 @@ class Solver(Protocol):
"""Planning horizon length in timesteps."""
...
def solve(self, info_dict: dict, init_action: torch.Tensor | None = None) -> dict:
def solve(
self,
info_dict: dict,
init_action: torch.Tensor | None = None,
active_mask: torch.Tensor | np.ndarray | None = None,
) -> dict:
"""Solve the planning optimization problem to find optimal actions.
Args:

View File

@@ -984,17 +984,32 @@ class World:
)
# run normal evaluation for eval_budget and record video
active_mask = np.ones(self.num_envs, dtype=bool)
last_eval_step = 0
for i in range(eval_budget):
video_frames[:, i] = self.infos['pixels'][:, -1]
last_eval_step = i
self.infos.update(goal_step)
self.step()
actions = self.policy.get_action(self.infos, active_mask=active_mask)
(
self.states,
self.rewards,
self.terminateds,
self.truncateds,
self.infos,
) = self.envs.step(actions)
results['episode_successes'] = np.logical_or(
results['episode_successes'], self.terminateds
)
active_mask = np.logical_not(results['episode_successes'])
if not np.any(active_mask):
break
# for auto-reset
self.envs.unwrapped._autoreset_envs = np.zeros((self.num_envs,))
video_frames[:, -1] = self.infos['pixels'][:, -1]
video_frames[:, last_eval_step] = self.infos['pixels'][:, -1]
if last_eval_step + 1 < eval_budget:
video_frames[:, last_eval_step + 1 :] = video_frames[:, last_eval_step : last_eval_step + 1]
n_episodes = len(episodes_idx)

View File

@@ -2,9 +2,9 @@ _target_: stable_worldmodel.solver.CEMSolver
model: ???
batch_size: 16
# Original defaults: num_samples=300, n_steps=30, topk=30, batch_size=8.
num_samples: 64
num_samples: 300
var_scale: 1.0
n_steps: 10
n_steps: 30
topk: 8
device: "cuda"
seed: ${seed}