Files
lewm/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py
qihuanye 25e4ddb628 继续做了通用性能优化,重点从 jepa.py 热路径转到实际的 stable_worldmodel
solver/policy 边界:去掉 CEM 每轮 cpu().tolist() 和结果过早回 CPU,把
  plan/warm-start 保持在 GPU,只在 env.step 前最后一步转成 numpy,同时补
  了输入张量的 contiguous 处理;
2026-04-09 12:33:50 +00:00

214 lines
7.9 KiB
Python

"""Cross Entropy Method solver for model-based planning."""
import time
from typing import Any
import gymnasium as gym
import numpy as np
import torch
from gymnasium.spaces import Box
from loguru import logger as logging
from .solver import Costable
class CEMSolver:
    """Cross Entropy Method solver for action optimization.

    Each iteration samples ``num_samples`` action sequences per environment
    from a Gaussian, scores them with ``model.get_cost``, keeps the ``topk``
    lowest-cost sequences and refits the Gaussian to those elites.

    Args:
        model: World model implementing the Costable protocol.
        batch_size: Number of environments to process in parallel.
        num_samples: Number of action candidates to sample per iteration.
        var_scale: Initial scale of the action distribution.
        n_steps: Number of CEM iterations.
        topk: Number of elite samples to keep for distribution update.
            The refit uses ``Tensor.std`` with its default settings, which is
            NaN for a single sample, so ``topk`` should be at least 2.
        device: Device for tensor computations.
        seed: Random seed for reproducibility.
    """

    def __init__(
        self,
        model: Costable,
        batch_size: int = 1,
        num_samples: int = 300,
        var_scale: float = 1,
        n_steps: int = 30,
        topk: int = 30,
        device: str | torch.device = "cpu",
        seed: int = 1234,
    ) -> None:
        self.model = model
        self.batch_size = batch_size
        self.var_scale = var_scale
        self.num_samples = num_samples
        self.n_steps = n_steps
        self.topk = topk
        self.device = device
        # Dedicated generator so sampling is reproducible and independent of
        # the global torch RNG state.
        self.torch_gen = torch.Generator(device=device).manual_seed(seed)

    def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
        """Configure the solver with environment specifications.

        Args:
            action_space: Action space of the environments; its first shape
                axis is skipped when computing the per-step action size.
            n_envs: Number of parallel environments.
            config: Object exposing ``horizon`` and ``action_block``.
        """
        # Warn before touching ``shape`` so the hint is emitted even if the
        # shape access below fails for an unsupported space.
        if not isinstance(action_space, Box):
            logging.warning(f"Action space is discrete, got {type(action_space)}. CEMSolver may not work as expected.")
        self._action_space = action_space
        self._n_envs = n_envs
        self._config = config
        self._action_dim = int(np.prod(action_space.shape[1:]))
        self._configured = True

    @property
    def n_envs(self) -> int:
        """Number of parallel environments."""
        return self._n_envs

    @property
    def action_dim(self) -> int:
        """Flattened action dimension including action_block grouping."""
        return self._action_dim * self._config.action_block

    @property
    def horizon(self) -> int:
        """Planning horizon in timesteps."""
        return self._config.horizon

    def __call__(self, *args: Any, **kwargs: Any) -> dict:
        """Make solver callable, forwarding to solve()."""
        return self.solve(*args, **kwargs)

    def init_action_distrib(
        self, actions: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Initialize the action distribution parameters (mean and variance).

        Args:
            actions: Optional warm start of shape
                ``(n_envs, t, action_dim)`` with ``t <= horizon``. Missing
                trailing timesteps are filled with zeros.

        Returns:
            ``(mean, var)``, both of shape ``(n_envs, horizon, action_dim)``
            on the solver device. ``mean`` never shares storage with
            ``actions``, so in-place updates made by ``solve`` cannot leak
            into the caller's tensor.

        Raises:
            ValueError: If ``actions`` covers more timesteps than the horizon.
        """
        device = torch.device(self.device)
        full_shape = [self.n_envs, self.horizon, self.action_dim]
        var = self.var_scale * torch.ones(full_shape, device=device)

        if actions is None:
            return torch.zeros(full_shape, device=device), var

        remaining = self.horizon - actions.shape[1]
        if remaining < 0:
            raise ValueError(
                f"init actions span {actions.shape[1]} timesteps, "
                f"which exceeds the planning horizon ({self.horizon})"
            )

        if remaining == 0:
            # Full-length warm start: force a copy, otherwise the returned
            # mean would alias the caller's tensor.
            mean = actions.to(device=device, non_blocking=True, copy=True)
        else:
            padding = torch.zeros(
                [self.n_envs, remaining, self.action_dim],
                device=device,
            )
            # torch.cat always allocates a new tensor.
            mean = torch.cat(
                [actions.to(device=device, non_blocking=True), padding], dim=1
            )
        return mean.contiguous(), var

    def _expand_info(
        self,
        info_dict: dict,
        start_idx: int,
        end_idx: int,
        device: torch.device,
    ) -> dict:
        """Slice ``info_dict`` to one env batch and add a per-sample axis."""
        current_bs = end_idx - start_idx
        expanded = {}
        for key, value in info_dict.items():
            # Each value is indexed along its first axis by environment.
            value_batch = value[start_idx:end_idx]
            if torch.is_tensor(value):
                # ``.to`` is a no-op when the tensor already lives on device.
                value_batch = value_batch.to(device, non_blocking=True)
                # (batch, ...) -> (batch, 1, ...) -> (batch, num_samples, ...)
                value_batch = value_batch.unsqueeze(1)
                value_batch = value_batch.expand(
                    current_bs, self.num_samples, *value_batch.shape[2:]
                )
            elif isinstance(value, np.ndarray):
                value_batch = np.repeat(
                    value_batch[:, None, ...], self.num_samples, axis=1
                )
            # Any other sliceable value is forwarded as the plain slice.
            expanded[key] = value_batch
        return expanded

    @torch.inference_mode()
    def solve(
        self, info_dict: dict, init_action: torch.Tensor | None = None
    ) -> dict:
        """Solve the planning problem using Cross Entropy Method.

        Args:
            info_dict: Per-environment context passed to ``model.get_cost``;
                every value is sliced along its first axis by environment.
            init_action: Optional warm start, see ``init_action_distrib``.
                It is not modified.

        Returns:
            Dict with ``"costs"`` (one float per environment: mean elite cost
            of the last iteration, NaN if no iteration ran), ``"mean"`` and
            ``"var"`` (single-element lists holding the final distribution
            parameters) and ``"actions"`` (the final mean).
        """
        start_time = time.time()
        device = torch.device(self.device)

        # -- initialize the action distribution globally
        mean, var = self.init_action_distrib(init_action)
        mean = mean.to(device, non_blocking=True)
        var = var.to(device, non_blocking=True)

        total_envs = self.n_envs
        per_batch_costs = []

        # --- Iterate over batches of environments ---
        for start_idx in range(0, total_envs, self.batch_size):
            end_idx = min(start_idx + self.batch_size, total_envs)
            current_bs = end_idx - start_idx

            batch_mean = mean[start_idx:end_idx]
            # NOTE: despite the name, this is used as the standard deviation
            # of the sampling distribution and is refit with ``std`` below.
            batch_var = var[start_idx:end_idx]

            expanded_infos = self._expand_info(info_dict, start_idx, end_idx, device)

            final_batch_cost = None
            batch_indices = torch.arange(current_bs, device=device).unsqueeze(1)

            for _ in range(self.n_steps):
                # Sample action sequences: (Batch, Num_Samples, Horizon, Dim)
                candidates = torch.randn(
                    current_bs,
                    self.num_samples,
                    self.horizon,
                    self.action_dim,
                    generator=self.torch_gen,
                    device=device,
                )
                # Scale and shift: (B, N, H, D) * (B, 1, H, D) + (B, 1, H, D)
                candidates = candidates * batch_var.unsqueeze(1) + batch_mean.unsqueeze(1)
                # Force the first sample to be the current mean so the
                # incumbent plan is always among the evaluated candidates.
                candidates[:, 0] = batch_mean

                # Shallow copy: the model may add/remove keys without
                # affecting later iterations.
                current_info = expanded_infos.copy()
                costs = self.model.get_cost(current_info, candidates)

                assert isinstance(costs, torch.Tensor), f"Expected cost to be a torch.Tensor, got {type(costs)}"
                assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, (
                    f"Expected cost to be of shape ({current_bs}, {self.num_samples}), got {costs.shape}"
                )

                # Lowest-cost elites; both outputs are (Batch, K).
                topk_vals, topk_inds = torch.topk(costs, k=self.topk, dim=1, largest=False)
                # Advanced indexing -> (Batch, K, Horizon, Dim)
                topk_candidates = candidates[batch_indices, topk_inds]

                # Refit the distribution to the elites.
                batch_mean = topk_candidates.mean(dim=1)
                batch_var = topk_candidates.std(dim=1)

                # Mean elite cost, kept for logging.
                final_batch_cost = topk_vals.mean(dim=1).detach()

            if final_batch_cost is None:
                # n_steps == 0: nothing was evaluated, report unknown costs
                # instead of crashing in torch.cat below.
                final_batch_cost = torch.full((current_bs,), float("nan"), device=device)

            # Write results back to global storage
            mean[start_idx:end_idx] = batch_mean
            var[start_idx:end_idx] = batch_var
            per_batch_costs.append(final_batch_cost)

        # Single device->host transfer for all costs; empty when n_envs == 0.
        costs_list = torch.cat(per_batch_costs).cpu().tolist() if per_batch_costs else []

        outputs = {
            "costs": costs_list,
            "mean": [mean.detach()],
            "var": [var.detach()],
        }
        outputs["actions"] = mean.detach()
        print(f"CEM solve time: {time.time() - start_time:.4f} seconds")
        return outputs