"""Cross Entropy Method solver for model-based planning.""" import time from typing import Any import gymnasium as gym import numpy as np import torch from gymnasium.spaces import Box from loguru import logger as logging from .solver import Costable class CEMSolver: """Cross Entropy Method solver for action optimization. Args: model: World model implementing the Costable protocol. batch_size: Number of environments to process in parallel. num_samples: Number of action candidates to sample per iteration. var_scale: Initial variance scale for the action distribution. n_steps: Number of CEM iterations. topk: Number of elite samples to keep for distribution update. device: Device for tensor computations. seed: Random seed for reproducibility. """ def __init__( self, model: Costable, batch_size: int = 1, num_samples: int = 300, var_scale: float = 1, n_steps: int = 30, topk: int = 30, device: str | torch.device = "cpu", seed: int = 1234, ) -> None: self.model = model self.batch_size = batch_size self.var_scale = var_scale self.num_samples = num_samples self.n_steps = n_steps self.topk = topk self.device = device self.torch_gen = torch.Generator(device=device).manual_seed(seed) def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None: """Configure the solver with environment specifications.""" self._action_space = action_space self._n_envs = n_envs self._config = config self._action_dim = int(np.prod(action_space.shape[1:])) self._configured = True if not isinstance(action_space, Box): logging.warning(f"Action space is discrete, got {type(action_space)}. CEMSolver may not work as expected.") @property def n_envs(self) -> int: """Number of parallel environments.""" return self._n_envs @property def action_dim(self) -> int: """Flattened action dimension including action_block grouping.""" return self._action_dim * self._config.action_block @property def horizon(self) -> int: """Planning horizon in timesteps.""" return self._config.horizon def __call__(self, *args: Any, **kwargs: Any) -> dict: """Make solver callable, forwarding to solve().""" return self.solve(*args, **kwargs) def init_action_distrib( self, actions: torch.Tensor | None = None ) -> tuple[torch.Tensor, torch.Tensor]: """Initialize the action distribution parameters (mean and variance).""" device = torch.device(self.device) var = self.var_scale * torch.ones( [self.n_envs, self.horizon, self.action_dim], device=device, ) mean = ( torch.zeros([self.n_envs, 0, self.action_dim], device=device) if actions is None else actions ) remaining = self.horizon - mean.shape[1] if remaining > 0: new_mean = torch.zeros( [self.n_envs, remaining, self.action_dim], device=mean.device, ) mean = torch.cat([mean, new_mean], dim=1) return mean, var @torch.inference_mode() def solve( self, info_dict: dict, init_action: torch.Tensor | None = None ) -> dict: """Solve the planning problem using Cross Entropy Method.""" start_time = time.time() outputs = { "costs": [], "mean": [], # History of means "var": [], # History of vars } # -- initialize the action distribution globally mean, var = self.init_action_distrib(init_action) if mean.device != torch.device(self.device): mean = mean.to(self.device, non_blocking=True) if var.device != torch.device(self.device): var = var.to(self.device, non_blocking=True) total_envs = self.n_envs # --- Iterate over batches --- for start_idx in range(0, total_envs, self.batch_size): end_idx = min(start_idx + self.batch_size, total_envs) current_bs = end_idx - start_idx # Slice Distribution Parameters for current batch batch_mean = mean[start_idx:end_idx] batch_var = var[start_idx:end_idx] # Expand Info Dict for current batch expanded_infos = {} for k, v in info_dict.items(): # v is shape (n_envs, ...) # Slice batch v_batch = v[start_idx:end_idx] if torch.is_tensor(v): if v_batch.device != self.device: v_batch = v_batch.to(self.device, non_blocking=True) # Add sample dim: (batch, 1, ...) v_batch = v_batch.unsqueeze(1) # Expand: (batch, num_samples, ...) v_batch = v_batch.expand(current_bs, self.num_samples, *v_batch.shape[2:]) elif isinstance(v, np.ndarray): v_batch = np.repeat(v_batch[:, None, ...], self.num_samples, axis=1) expanded_infos[k] = v_batch # Optimization Loop final_batch_cost = None batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1) for step in range(self.n_steps): # Sample action sequences: (Batch, Num_Samples, Horizon, Dim) candidates = torch.randn( current_bs, self.num_samples, self.horizon, self.action_dim, generator=self.torch_gen, device=self.device, ) # Scale and shift: (Batch, N, H, D) * (Batch, 1, H, D) + (Batch, 1, H, D) candidates = candidates * batch_var.unsqueeze(1) + batch_mean.unsqueeze(1) # Force the first sample to be the current mean candidates[:, 0] = batch_mean current_info = expanded_infos.copy() # Evaluate candidates costs = self.model.get_cost(current_info, candidates) assert isinstance(costs, torch.Tensor), f"Expected cost to be a torch.Tensor, got {type(costs)}" assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, ( f"Expected cost to be of shape ({current_bs}, {self.num_samples}), got {costs.shape}" ) # Select Top-K # topk_vals: (Batch, K), topk_inds: (Batch, K) topk_vals, topk_inds = torch.topk(costs, k=self.topk, dim=1, largest=False) # Gather Top-K Candidates # We need to select the specific candidates corresponding to topk_inds # Indexing: candidates[batch_idx, sample_idx] # Result shape: (Batch, K, Horizon, Dim) topk_candidates = candidates[batch_indices, topk_inds] # Update Mean and Variance based on Top-K batch_mean = topk_candidates.mean(dim=1) batch_var = topk_candidates.std(dim=1) # Update final cost for logging # We average the cost of the top elites final_batch_cost = topk_vals.mean(dim=1).detach() # Write results back to global storage mean[start_idx:end_idx] = batch_mean var[start_idx:end_idx] = batch_var # Store history/metadata outputs["costs"].append(final_batch_cost) outputs["costs"] = torch.cat(outputs["costs"]).cpu().tolist() outputs["actions"] = mean.detach() outputs["mean"] = [mean.detach()] outputs["var"] = [var.detach()] print(f"CEM solve time: {time.time() - start_time:.4f} seconds") return outputs