diff --git a/.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py b/.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py index 1d8c85e..63c408a 100644 --- a/.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py +++ b/.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py @@ -394,6 +394,16 @@ class WorldModelPolicy(BasePolicy): 'Solver must implement the Solver protocol' ) + def _normalize_active_mask(self, active_mask: Any) -> np.ndarray | None: + if active_mask is None: + return None + active_mask = np.asarray(active_mask, dtype=bool) + if active_mask.ndim != 1 or active_mask.shape[0] != self.env.num_envs: + raise ValueError( + f"active_mask must have shape ({self.env.num_envs},), got {active_mask.shape}" + ) + return active_mask + def get_action(self, info_dict: dict, **kwargs: Any) -> np.ndarray: """Get action via planning with the world model. @@ -405,17 +415,34 @@ class WorldModelPolicy(BasePolicy): The selected action(s) as a numpy array. """ assert hasattr(self, 'env'), 'Environment not set for the policy' - assert 'pixels' in info_dict, "'pixels' must be provided in info_dict" - assert 'goal' in info_dict, "'goal' must be provided in info_dict" - - info_dict = self._prepare_info(info_dict) - info_dict = self._move_info_to_device(info_dict, self.solver.device) + active_mask = self._normalize_active_mask(kwargs.get("active_mask")) # need to replan if action buffer is empty if len(self._action_buffer) == 0: - outputs = self.solver(info_dict, init_action=self._next_init) + assert 'pixels' in info_dict, "'pixels' must be provided in info_dict" + assert 'goal' in info_dict, "'goal' must be provided in info_dict" + + info_dict = self._prepare_info(info_dict) + info_dict = self._move_info_to_device(info_dict, self.solver.device) + + outputs = self.solver( + info_dict, + init_action=self._next_init, + active_mask=active_mask, + ) actions = outputs['actions'] # (num_envs, horizon, action_dim) + if active_mask is not None and actions.shape[0] != self.env.num_envs: + full_actions = torch.zeros( + self.env.num_envs, + actions.shape[1], + actions.shape[2], + dtype=actions.dtype, + device=actions.device, + ) + full_actions[torch.as_tensor(active_mask, device=actions.device)] = actions + actions = full_actions + keep_horizon = self.cfg.receding_horizon plan = actions[:, :keep_horizon] rest = actions[:, keep_horizon:] @@ -430,6 +457,16 @@ class WorldModelPolicy(BasePolicy): action = self._action_buffer.popleft() action = action.reshape(*self.env.action_space.shape) + if active_mask is not None: + if torch.is_tensor(action): + inactive_mask = torch.as_tensor( + ~active_mask, device=action.device, dtype=torch.bool + ) + action = action.clone() + action[inactive_mask] = 0 + else: + action = np.array(action, copy=True) + action[~active_mask] = 0 if torch.is_tensor(action): action = action.detach().cpu().numpy() else: diff --git a/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py b/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py index 328d141..346d699 100644 --- a/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py +++ b/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py @@ -76,6 +76,24 @@ class CEMSolver: """Make solver callable, forwarding to solve().""" return self.solve(*args, **kwargs) + @staticmethod + def _normalize_active_mask( + active_mask: torch.Tensor | np.ndarray | None, + n_envs: int, + device: torch.device, + ) -> torch.Tensor | None: + if active_mask is None: + return None + if not torch.is_tensor(active_mask): + active_mask = torch.as_tensor(active_mask, dtype=torch.bool, device=device) + else: + active_mask = active_mask.to(device=device, dtype=torch.bool) + if active_mask.ndim != 1 or active_mask.shape[0] != n_envs: + raise ValueError( + f"active_mask must have shape ({n_envs},), got {tuple(active_mask.shape)}" + ) + return active_mask + def init_action_distrib( self, actions: torch.Tensor | None = None ) -> tuple[torch.Tensor, torch.Tensor]: @@ -103,7 +121,10 @@ class CEMSolver: @torch.inference_mode() def solve( - self, info_dict: dict, init_action: torch.Tensor | None = None + self, + info_dict: dict, + init_action: torch.Tensor | None = None, + active_mask: torch.Tensor | np.ndarray | None = None, ) -> dict: """Solve the planning problem using Cross Entropy Method.""" start_time = time.time() @@ -119,6 +140,17 @@ class CEMSolver: mean = mean.to(self.device, non_blocking=True) if var.device != torch.device(self.device): var = var.to(self.device, non_blocking=True) + active_mask = self._normalize_active_mask( + active_mask, self.n_envs, torch.device(self.device) + ) + + if active_mask is not None and not torch.any(active_mask): + return { + "costs": [], + "actions": mean.detach(), + "mean": [mean.detach()], + "var": [var.detach()], + } total_envs = self.n_envs @@ -148,6 +180,25 @@ class CEMSolver: v_batch = np.repeat(v_batch[:, None, ...], self.num_samples, axis=1) expanded_infos[k] = v_batch + if active_mask is not None: + batch_mask = active_mask[start_idx:end_idx] + if not torch.any(batch_mask): + outputs["costs"].append( + torch.full((current_bs,), float("nan"), device=self.device) + ) + continue + active_local = torch.nonzero(batch_mask, as_tuple=False).squeeze(1) + active_local_np = active_local.detach().cpu().numpy() + batch_mean = batch_mean[active_local] + batch_var = batch_var[active_local] + expanded_infos = { + k: (v[active_local] if torch.is_tensor(v) else v[active_local_np]) + for k, v in expanded_infos.items() + } + current_bs = int(active_local.numel()) + else: + active_local = None + # Optimization Loop final_batch_cost = None batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1) @@ -198,13 +249,26 @@ class CEMSolver: final_batch_cost = topk_vals.mean(dim=1).detach() # Write results back to global storage - mean[start_idx:end_idx] = batch_mean - var[start_idx:end_idx] = batch_var + if active_mask is not None: + global_indices = start_idx + active_local + mean[global_indices] = batch_mean + var[global_indices] = batch_var + batch_costs = torch.full( + (end_idx - start_idx,), float("nan"), device=self.device + ) + batch_costs[active_local] = final_batch_cost + else: + mean[start_idx:end_idx] = batch_mean + var[start_idx:end_idx] = batch_var + batch_costs = final_batch_cost # Store history/metadata - outputs["costs"].append(final_batch_cost) + outputs["costs"].append(batch_costs) - outputs["costs"] = torch.cat(outputs["costs"]).cpu().tolist() + if outputs["costs"]: + outputs["costs"] = torch.cat(outputs["costs"]).cpu().tolist() + else: + outputs["costs"] = [] outputs["actions"] = mean.detach() outputs["mean"] = [mean.detach()] outputs["var"] = [var.detach()] diff --git a/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/solver.py b/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/solver.py index f12a62e..0ce8876 100644 --- a/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/solver.py +++ b/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/solver.py @@ -1,6 +1,7 @@ from typing import Any, Protocol, runtime_checkable import gymnasium as gym +import numpy as np import torch @@ -61,7 +62,12 @@ class Solver(Protocol): """Planning horizon length in timesteps.""" ... - def solve(self, info_dict: dict, init_action: torch.Tensor | None = None) -> dict: + def solve( + self, + info_dict: dict, + init_action: torch.Tensor | None = None, + active_mask: torch.Tensor | np.ndarray | None = None, + ) -> dict: """Solve the planning optimization problem to find optimal actions. Args: diff --git a/.venv/lib/python3.10/site-packages/stable_worldmodel/world.py b/.venv/lib/python3.10/site-packages/stable_worldmodel/world.py index f3c348b..8a6334f 100644 --- a/.venv/lib/python3.10/site-packages/stable_worldmodel/world.py +++ b/.venv/lib/python3.10/site-packages/stable_worldmodel/world.py @@ -984,17 +984,32 @@ class World: ) # run normal evaluation for eval_budget and record video + active_mask = np.ones(self.num_envs, dtype=bool) + last_eval_step = 0 for i in range(eval_budget): video_frames[:, i] = self.infos['pixels'][:, -1] + last_eval_step = i self.infos.update(goal_step) - self.step() + actions = self.policy.get_action(self.infos, active_mask=active_mask) + ( + self.states, + self.rewards, + self.terminateds, + self.truncateds, + self.infos, + ) = self.envs.step(actions) results['episode_successes'] = np.logical_or( results['episode_successes'], self.terminateds ) + active_mask = np.logical_not(results['episode_successes']) + if not np.any(active_mask): + break # for auto-reset self.envs.unwrapped._autoreset_envs = np.zeros((self.num_envs,)) - video_frames[:, -1] = self.infos['pixels'][:, -1] + video_frames[:, last_eval_step] = self.infos['pixels'][:, -1] + if last_eval_step + 1 < eval_budget: + video_frames[:, last_eval_step + 1 :] = video_frames[:, last_eval_step : last_eval_step + 1] n_episodes = len(episodes_idx) diff --git a/config/eval/solver/cem.yaml b/config/eval/solver/cem.yaml index add68d8..807fb3b 100644 --- a/config/eval/solver/cem.yaml +++ b/config/eval/solver/cem.yaml @@ -2,9 +2,9 @@ _target_: stable_worldmodel.solver.CEMSolver model: ??? batch_size: 16 # Original defaults: num_samples=300, n_steps=30, topk=30, batch_size=8. -num_samples: 64 +num_samples: 300 var_scale: 1.0 -n_steps: 10 +n_steps: 30 topk: 8 device: "cuda" seed: ${seed}