Optimize CEM input transfers before sample expansion

This commit is contained in:
qihuanye
2026-04-08 13:01:24 +00:00
parent fa1c15c896
commit 12ba4f4352
2 changed files with 372 additions and 0 deletions

View File

@@ -0,0 +1,201 @@
"""Cross Entropy Method solver for model-based planning."""
import time
from typing import Any
import gymnasium as gym
import numpy as np
import torch
from gymnasium.spaces import Box
from loguru import logger as logging
from .solver import Costable
class CEMSolver:
"""Cross Entropy Method solver for action optimization.
Args:
model: World model implementing the Costable protocol.
batch_size: Number of environments to process in parallel.
num_samples: Number of action candidates to sample per iteration.
var_scale: Initial variance scale for the action distribution.
n_steps: Number of CEM iterations.
topk: Number of elite samples to keep for distribution update.
device: Device for tensor computations.
seed: Random seed for reproducibility.
"""
def __init__(
self,
model: Costable,
batch_size: int = 1,
num_samples: int = 300,
var_scale: float = 1,
n_steps: int = 30,
topk: int = 30,
device: str | torch.device = "cpu",
seed: int = 1234,
) -> None:
self.model = model
self.batch_size = batch_size
self.var_scale = var_scale
self.num_samples = num_samples
self.n_steps = n_steps
self.topk = topk
self.device = device
self.torch_gen = torch.Generator(device=device).manual_seed(seed)
def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
"""Configure the solver with environment specifications."""
self._action_space = action_space
self._n_envs = n_envs
self._config = config
self._action_dim = int(np.prod(action_space.shape[1:]))
self._configured = True
if not isinstance(action_space, Box):
logging.warning(f"Action space is discrete, got {type(action_space)}. CEMSolver may not work as expected.")
@property
def n_envs(self) -> int:
"""Number of parallel environments."""
return self._n_envs
@property
def action_dim(self) -> int:
"""Flattened action dimension including action_block grouping."""
return self._action_dim * self._config.action_block
@property
def horizon(self) -> int:
"""Planning horizon in timesteps."""
return self._config.horizon
def __call__(self, *args: Any, **kwargs: Any) -> dict:
"""Make solver callable, forwarding to solve()."""
return self.solve(*args, **kwargs)
def init_action_distrib(
self, actions: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""Initialize the action distribution parameters (mean and variance)."""
var = self.var_scale * torch.ones([self.n_envs, self.horizon, self.action_dim])
mean = torch.zeros([self.n_envs, 0, self.action_dim]) if actions is None else actions
remaining = self.horizon - mean.shape[1]
if remaining > 0:
device = mean.device
new_mean = torch.zeros([self.n_envs, remaining, self.action_dim])
mean = torch.cat([mean, new_mean], dim=1).to(device)
return mean, var
@torch.inference_mode()
def solve(
self, info_dict: dict, init_action: torch.Tensor | None = None
) -> dict:
"""Solve the planning problem using Cross Entropy Method."""
start_time = time.time()
outputs = {
"costs": [],
"mean": [], # History of means
"var": [], # History of vars
}
# -- initialize the action distribution globally
mean, var = self.init_action_distrib(init_action)
mean = mean.to(self.device)
var = var.to(self.device)
total_envs = self.n_envs
# --- Iterate over batches ---
for start_idx in range(0, total_envs, self.batch_size):
end_idx = min(start_idx + self.batch_size, total_envs)
current_bs = end_idx - start_idx
# Slice Distribution Parameters for current batch
batch_mean = mean[start_idx:end_idx]
batch_var = var[start_idx:end_idx]
# Expand Info Dict for current batch
expanded_infos = {}
for k, v in info_dict.items():
# v is shape (n_envs, ...)
# Slice batch
v_batch = v[start_idx:end_idx]
if torch.is_tensor(v):
if v_batch.device != self.device:
v_batch = v_batch.to(self.device, non_blocking=True)
# Add sample dim: (batch, 1, ...)
v_batch = v_batch.unsqueeze(1)
# Expand: (batch, num_samples, ...)
v_batch = v_batch.expand(current_bs, self.num_samples, *v_batch.shape[2:])
elif isinstance(v, np.ndarray):
v_batch = np.repeat(v_batch[:, None, ...], self.num_samples, axis=1)
expanded_infos[k] = v_batch
# Optimization Loop
final_batch_cost = None
for step in range(self.n_steps):
# Sample action sequences: (Batch, Num_Samples, Horizon, Dim)
candidates = torch.randn(
current_bs,
self.num_samples,
self.horizon,
self.action_dim,
generator=self.torch_gen,
device=self.device,
)
# Scale and shift: (Batch, N, H, D) * (Batch, 1, H, D) + (Batch, 1, H, D)
candidates = candidates * batch_var.unsqueeze(1) + batch_mean.unsqueeze(1)
# Force the first sample to be the current mean
candidates[:, 0] = batch_mean
current_info = expanded_infos.copy()
# Evaluate candidates
costs = self.model.get_cost(current_info, candidates)
assert isinstance(costs, torch.Tensor), f"Expected cost to be a torch.Tensor, got {type(costs)}"
assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, (
f"Expected cost to be of shape ({current_bs}, {self.num_samples}), got {costs.shape}"
)
# Select Top-K
# topk_vals: (Batch, K), topk_inds: (Batch, K)
topk_vals, topk_inds = torch.topk(costs, k=self.topk, dim=1, largest=False)
# Gather Top-K Candidates
# We need to select the specific candidates corresponding to topk_inds
batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1).expand(-1, self.topk)
# Indexing: candidates[batch_idx, sample_idx]
# Result shape: (Batch, K, Horizon, Dim)
topk_candidates = candidates[batch_indices, topk_inds]
# Update Mean and Variance based on Top-K
batch_mean = topk_candidates.mean(dim=1)
batch_var = topk_candidates.std(dim=1)
# Update final cost for logging
# We average the cost of the top elites
final_batch_cost = topk_vals.mean(dim=1).cpu().tolist()
# Write results back to global storage
mean[start_idx:end_idx] = batch_mean
var[start_idx:end_idx] = batch_var
# Store history/metadata
outputs["costs"].extend(final_batch_cost)
outputs["actions"] = mean.detach().cpu()
outputs["mean"] = [mean.detach().cpu()]
outputs["var"] = [var.detach().cpu()]
print(f"CEM solve time: {time.time() - start_time:.4f} seconds")
return outputs

View File

@@ -55,3 +55,174 @@ metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True,
True, True, True, True, True]), 'seeds': None}
evaluation_time: 133.1857841014862 seconds
inference_precision: fp32
==== CONFIG ====
cache_dir: null
solver:
_target_: stable_worldmodel.solver.CEMSolver
model: ???
batch_size: 1
num_samples: 300
var_scale: 1.0
n_steps: 30
topk: 30
device: cuda
seed: ${seed}
world:
env_name: swm/TwoRoom-v1
num_envs: ${eval.num_eval}
max_episode_steps: 100
history_size: 1
frame_skip: 1
seed: 42
policy: two-room/tworoom/lejepa
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: tworoom
callables:
- method: _set_state
args:
state:
value: proprio
- method: _set_goal_state
args:
goal_state:
value: goal_proprio
output:
filename: tworoom_results.txt
==== RESULTS ====
metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
True, True, True, True, True, True, True, True, True,
True, True, True, False, True, True, True, True, True,
True, True, True, True, False, True, True, True, True,
True, True, False, True, True, True, True, True, True,
True, True, True, True, True]), 'seeds': None}
evaluation_time: 131.6325900554657 seconds
inference_precision: fp32
==== CONFIG ====
cache_dir: null
solver:
_target_: stable_worldmodel.solver.CEMSolver
model: ???
batch_size: 1
num_samples: 300
var_scale: 1.0
n_steps: 30
topk: 30
device: cuda
seed: ${seed}
world:
env_name: swm/TwoRoom-v1
num_envs: ${eval.num_eval}
max_episode_steps: 100
history_size: 1
frame_skip: 1
seed: 42
policy: two-room/tworoom/lejepa
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: tworoom
callables:
- method: _set_state
args:
state:
value: proprio
- method: _set_goal_state
args:
goal_state:
value: goal_proprio
output:
filename: tworoom_results.txt
==== RESULTS ====
metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
True, True, True, True, True, True, True, True, True,
True, True, True, False, True, True, True, True, True,
True, True, True, True, False, True, True, True, True,
True, True, False, True, True, True, True, True, True,
True, True, True, True, True]), 'seeds': None}
evaluation_time: 119.98270344734192 seconds
inference_precision: fp32
==== CONFIG ====
cache_dir: null
solver:
_target_: stable_worldmodel.solver.CEMSolver
model: ???
batch_size: 1
num_samples: 300
var_scale: 1.0
n_steps: 30
topk: 30
device: cuda
seed: ${seed}
world:
env_name: swm/TwoRoom-v1
num_envs: ${eval.num_eval}
max_episode_steps: 100
history_size: 1
frame_skip: 1
seed: 42
policy: two-room/tworoom/lejepa
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: tworoom
callables:
- method: _set_state
args:
state:
value: proprio
- method: _set_goal_state
args:
goal_state:
value: goal_proprio
output:
filename: tworoom_results.txt
==== RESULTS ====
metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
True, True, True, True, True, True, True, True, True,
True, True, True, False, True, True, True, True, True,
True, True, True, True, False, True, True, True, True,
True, True, False, True, True, True, True, True, True,
True, True, True, True, True]), 'seeds': None}
evaluation_time: 121.47896695137024 seconds
inference_precision: fp32