From 12ba4f4352dc9e91acb6e834c74c979fafdcb884 Mon Sep 17 00:00:00 2001
From: qihuanye
Date: Wed, 8 Apr 2026 13:01:24 +0000
Subject: [PATCH] Optimize CEM input transfers before sample expansion

---
Review note: the cem.py hunk below was re-flowed and revised during review, so
the blob id on its "index" line is stale; hunk and diffstat counts are updated.
The target path lives under .venv/, i.e. inside an installed package.

 .../stable_worldmodel/solver/cem.py | 249 ++++++++++++++++++
 tworoom_results.txt | 171 +++++++++++++++
 2 files changed, 420 insertions(+)
 create mode 100644 .venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py

diff --git a/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py b/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py
new file mode 100644
index 0000000..77fa14a
--- /dev/null
+++ b/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py
@@ -0,0 +1,249 @@
+"""Cross Entropy Method solver for model-based planning."""
+
+import time
+from typing import Any
+
+import gymnasium as gym
+import numpy as np
+import torch
+from gymnasium.spaces import Box
+from loguru import logger as logging
+
+from .solver import Costable
+
+
+class CEMSolver:
+    """Cross Entropy Method solver for action optimization.
+
+    Every iteration draws ``num_samples`` action sequences per environment
+    from a diagonal Gaussian, scores them with ``model.get_cost`` and refits
+    the Gaussian to the ``topk`` lowest-cost sequences.
+
+    Args:
+        model: World model implementing the Costable protocol.
+        batch_size: Number of environments to process in parallel.
+        num_samples: Number of action candidates to sample per iteration.
+        var_scale: Initial scale of the action distribution. It multiplies
+            standard-normal noise directly, so it acts as a standard
+            deviation despite its name.
+        n_steps: Number of CEM iterations.
+        topk: Number of elite samples to keep for distribution update.
+        device: Device for tensor computations.
+        seed: Random seed for reproducibility.
+    """
+
+    def __init__(
+        self,
+        model: Costable,
+        batch_size: int = 1,
+        num_samples: int = 300,
+        var_scale: float = 1,
+        n_steps: int = 30,
+        topk: int = 30,
+        device: str | torch.device = "cpu",
+        seed: int = 1234,
+    ) -> None:
+        self.model = model
+        self.batch_size = batch_size
+        self.var_scale = var_scale
+        self.num_samples = num_samples
+        self.n_steps = n_steps
+        self.topk = topk
+        self.device = device
+        # Dedicated generator so candidate sampling is reproducible per seed.
+        self.torch_gen = torch.Generator(device=device).manual_seed(seed)
+
+    def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
+        """Configure the solver with environment specifications.
+
+        Args:
+            action_space: Action space of the environments. The first axis of
+                its shape is skipped when the per-step action size is computed.
+            n_envs: Number of parallel environments.
+            config: Planning config; ``horizon`` and ``action_block`` are read
+                from it.
+        """
+        # Warn before touching ``shape`` so the hint is logged even if the
+        # lookup below fails for an unsupported space type.
+        if not isinstance(action_space, Box):
+            logging.warning(f"Action space is discrete, got {type(action_space)}. CEMSolver may not work as expected.")
+
+        self._action_space = action_space
+        self._n_envs = n_envs
+        self._config = config
+        self._action_dim = int(np.prod(action_space.shape[1:]))
+        self._configured = True
+
+    @property
+    def n_envs(self) -> int:
+        """Number of parallel environments."""
+        return self._n_envs
+
+    @property
+    def action_dim(self) -> int:
+        """Flattened action dimension including action_block grouping."""
+        return self._action_dim * self._config.action_block
+
+    @property
+    def horizon(self) -> int:
+        """Planning horizon in timesteps."""
+        return self._config.horizon
+
+    def __call__(self, *args: Any, **kwargs: Any) -> dict:
+        """Make solver callable, forwarding to solve()."""
+        return self.solve(*args, **kwargs)
+
+    def init_action_distrib(
+        self, actions: torch.Tensor | None = None
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Initialize the action distribution parameters (mean and variance).
+
+        Args:
+            actions: Optional warm-start means laid out as
+                ``(n_envs, steps, action_dim)``. When it covers fewer than
+                ``horizon`` steps the missing trailing steps are zero-filled.
+
+        Returns:
+            ``(mean, var)``. ``mean`` stays on the device of ``actions`` (the
+            default device when omitted); ``var`` is ``var_scale`` everywhere,
+            with shape ``(n_envs, horizon, action_dim)`` on the default device.
+        """
+        var = self.var_scale * torch.ones([self.n_envs, self.horizon, self.action_dim])
+        mean = torch.zeros([self.n_envs, 0, self.action_dim]) if actions is None else actions
+
+        remaining = self.horizon - mean.shape[1]
+        if remaining > 0:
+            # Allocate the padding next to ``mean``: torch.cat cannot join
+            # tensors that live on different devices, so a warm start on an
+            # accelerator must not be padded with a CPU tensor.
+            new_mean = torch.zeros([self.n_envs, remaining, self.action_dim], device=mean.device)
+            mean = torch.cat([mean, new_mean], dim=1)
+
+        return mean, var
+
+    @torch.inference_mode()
+    def solve(
+        self, info_dict: dict, init_action: torch.Tensor | None = None
+    ) -> dict:
+        """Solve the planning problem using Cross Entropy Method.
+
+        Args:
+            info_dict: Inputs for the cost model. Every value is sliced along
+                its first axis per batch; tensors and NumPy arrays also get a
+                sample axis of size ``num_samples`` inserted at position 1.
+            init_action: Optional warm start, see ``init_action_distrib``.
+                The tensor passed in is left unmodified.
+
+        Returns:
+            Dict with ``"actions"`` (final means on the CPU), ``"costs"``
+            (mean elite cost of the last iteration, one float per
+            environment; empty when ``n_steps`` is 0) and ``"mean"`` /
+            ``"var"`` (one-element lists with the final parameters on the CPU).
+        """
+        start_time = time.perf_counter()
+        outputs = {
+            "costs": [],
+            "mean": [],  # History of means
+            "var": [],  # History of vars
+        }
+
+        # -- initialize the action distribution globally
+        mean, var = self.init_action_distrib(init_action)
+        mean = mean.to(self.device)
+        var = var.to(self.device)
+        if mean is init_action:
+            # A full-horizon warm start already on the target device comes
+            # back as the caller's own tensor; copy it so the in-place
+            # write-back below does not overwrite the caller's data.
+            mean = mean.clone()
+
+        total_envs = self.n_envs
+
+        # --- Iterate over batches ---
+        for start_idx in range(0, total_envs, self.batch_size):
+            end_idx = min(start_idx + self.batch_size, total_envs)
+            current_bs = end_idx - start_idx
+
+            # Slice Distribution Parameters for current batch
+            batch_mean = mean[start_idx:end_idx]
+            batch_var = var[start_idx:end_idx]
+
+            # Expand Info Dict for current batch
+            expanded_infos = {}
+            for k, v in info_dict.items():
+                # v is shape (n_envs, ...)
+                v_batch = v[start_idx:end_idx]
+                if torch.is_tensor(v):
+                    # Transfer the per-env slice before expanding it. ``.to``
+                    # returns the tensor itself when it is already on the
+                    # target device, so no explicit device check is needed.
+                    v_batch = v_batch.to(self.device, non_blocking=True)
+                    # (batch, ...) -> (batch, 1, ...) -> (batch, num_samples, ...)
+                    v_batch = v_batch.unsqueeze(1)
+                    v_batch = v_batch.expand(current_bs, self.num_samples, *v_batch.shape[2:])
+                elif isinstance(v, np.ndarray):
+                    v_batch = np.repeat(v_batch[:, None, ...], self.num_samples, axis=1)
+                expanded_infos[k] = v_batch
+
+            # Row index for gathering the elites; it is the same for every step.
+            batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1).expand(-1, self.topk)
+
+            # Optimization Loop
+            topk_vals = None
+            for _ in range(self.n_steps):
+                # Sample action sequences: (Batch, Num_Samples, Horizon, Dim)
+                candidates = torch.randn(
+                    current_bs,
+                    self.num_samples,
+                    self.horizon,
+                    self.action_dim,
+                    generator=self.torch_gen,
+                    device=self.device,
+                )
+
+                # Scale and shift: (Batch, N, H, D) * (Batch, 1, H, D) + (Batch, 1, H, D)
+                candidates = candidates * batch_var.unsqueeze(1) + batch_mean.unsqueeze(1)
+
+                # Force the first sample to be the current mean
+                candidates[:, 0] = batch_mean
+
+                # Fresh shallow copy per step (presumably so get_cost may edit the dict).
+                current_info = expanded_infos.copy()
+
+                # Evaluate candidates
+                costs = self.model.get_cost(current_info, candidates)
+
+                assert isinstance(costs, torch.Tensor), f"Expected cost to be a torch.Tensor, got {type(costs)}"
+                assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, (
+                    f"Expected cost to be of shape ({current_bs}, {self.num_samples}), got {costs.shape}"
+                )
+
+                # Select Top-K (lowest cost)
+                # topk_vals: (Batch, K), topk_inds: (Batch, K)
+                topk_vals, topk_inds = torch.topk(costs, k=self.topk, dim=1, largest=False)
+
+                # Indexing: candidates[batch_idx, sample_idx]
+                # Result shape: (Batch, K, Horizon, Dim)
+                topk_candidates = candidates[batch_indices, topk_inds]
+
+                # Refit the distribution to the elites. Despite its name,
+                # batch_var holds a standard deviation: it scales the noise
+                # of the next step directly.
+                batch_mean = topk_candidates.mean(dim=1)
+                batch_var = topk_candidates.std(dim=1)
+
+            # Write results back to global storage
+            mean[start_idx:end_idx] = batch_mean
+            var[start_idx:end_idx] = batch_var
+
+            # Only the last step's elite cost is reported, so it is copied to
+            # the host once per batch instead of once per step.
+            if topk_vals is not None:
+                outputs["costs"].extend(topk_vals.mean(dim=1).cpu().tolist())
+
+        outputs["actions"] = mean.detach().cpu()
+        outputs["mean"] = [mean.detach().cpu()]
+        outputs["var"] = [var.detach().cpu()]
+
+        logging.info(f"CEM solve time: {time.perf_counter() - start_time:.4f} seconds")
+        return outputs
diff --git a/tworoom_results.txt b/tworoom_results.txt index 1770acb..1c8685c 100644 --- a/tworoom_results.txt +++ b/tworoom_results.txt @@ -55,3 +55,174 @@ metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, True, True, True, True, True]), 'seeds': None} evaluation_time: 133.1857841014862 seconds inference_precision: fp32 + +==== CONFIG ==== +cache_dir: null +solver: + _target_: stable_worldmodel.solver.CEMSolver + model: ??? + batch_size: 1 + num_samples: 300 + var_scale: 1.0 + n_steps: 30 + topk: 30 + device: cuda + seed: ${seed} +world: + env_name: swm/TwoRoom-v1 + num_envs: ${eval.num_eval} + max_episode_steps: 100 + history_size: 1 + frame_skip: 1 +seed: 42 +policy: two-room/tworoom/lejepa +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: tworoom + callables: + - method: _set_state + args: + state: + value: proprio + - method: _set_goal_state + args: + goal_state: + value: goal_proprio +output: + filename: tworoom_results.txt + +==== RESULTS ==== +metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False, + True, True, True, True, True, True, True, True, True, + True, True, True, False, True, True, True, True, True, + True, True, True, True, False, True, True, True, True, + True, True, False, True, True, True, True, True, True, + True, True, True, True, True]), 'seeds': None} +evaluation_time: 131.6325900554657 seconds 
+inference_precision: fp32 + +==== CONFIG ==== +cache_dir: null +solver: + _target_: stable_worldmodel.solver.CEMSolver + model: ??? + batch_size: 1 + num_samples: 300 + var_scale: 1.0 + n_steps: 30 + topk: 30 + device: cuda + seed: ${seed} +world: + env_name: swm/TwoRoom-v1 + num_envs: ${eval.num_eval} + max_episode_steps: 100 + history_size: 1 + frame_skip: 1 +seed: 42 +policy: two-room/tworoom/lejepa +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: tworoom + callables: + - method: _set_state + args: + state: + value: proprio + - method: _set_goal_state + args: + goal_state: + value: goal_proprio +output: + filename: tworoom_results.txt + +==== RESULTS ==== +metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False, + True, True, True, True, True, True, True, True, True, + True, True, True, False, True, True, True, True, True, + True, True, True, True, False, True, True, True, True, + True, True, False, True, True, True, True, True, True, + True, True, True, True, True]), 'seeds': None} +evaluation_time: 119.98270344734192 seconds +inference_precision: fp32 + +==== CONFIG ==== +cache_dir: null +solver: + _target_: stable_worldmodel.solver.CEMSolver + model: ??? 
+ batch_size: 1 + num_samples: 300 + var_scale: 1.0 + n_steps: 30 + topk: 30 + device: cuda + seed: ${seed} +world: + env_name: swm/TwoRoom-v1 + num_envs: ${eval.num_eval} + max_episode_steps: 100 + history_size: 1 + frame_skip: 1 +seed: 42 +policy: two-room/tworoom/lejepa +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: tworoom + callables: + - method: _set_state + args: + state: + value: proprio + - method: _set_goal_state + args: + goal_state: + value: goal_proprio +output: + filename: tworoom_results.txt + +==== RESULTS ==== +metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False, + True, True, True, True, True, True, True, True, True, + True, True, True, False, True, True, True, True, True, + True, True, True, True, False, True, True, True, True, + True, True, False, True, True, True, True, True, True, + True, True, True, True, True]), 'seeds': None} +evaluation_time: 121.47896695137024 seconds +inference_precision: fp32