Optimize CEM input transfers before sample expansion
This commit is contained in:
@@ -0,0 +1,201 @@
|
||||
"""Cross Entropy Method solver for model-based planning."""
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from gymnasium.spaces import Box
|
||||
from loguru import logger as logging
|
||||
|
||||
from .solver import Costable
|
||||
|
||||
|
||||
class CEMSolver:
|
||||
"""Cross Entropy Method solver for action optimization.
|
||||
|
||||
Args:
|
||||
model: World model implementing the Costable protocol.
|
||||
batch_size: Number of environments to process in parallel.
|
||||
num_samples: Number of action candidates to sample per iteration.
|
||||
var_scale: Initial variance scale for the action distribution.
|
||||
n_steps: Number of CEM iterations.
|
||||
topk: Number of elite samples to keep for distribution update.
|
||||
device: Device for tensor computations.
|
||||
seed: Random seed for reproducibility.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Costable,
|
||||
batch_size: int = 1,
|
||||
num_samples: int = 300,
|
||||
var_scale: float = 1,
|
||||
n_steps: int = 30,
|
||||
topk: int = 30,
|
||||
device: str | torch.device = "cpu",
|
||||
seed: int = 1234,
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.batch_size = batch_size
|
||||
self.var_scale = var_scale
|
||||
self.num_samples = num_samples
|
||||
self.n_steps = n_steps
|
||||
self.topk = topk
|
||||
self.device = device
|
||||
self.torch_gen = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
|
||||
"""Configure the solver with environment specifications."""
|
||||
self._action_space = action_space
|
||||
self._n_envs = n_envs
|
||||
self._config = config
|
||||
self._action_dim = int(np.prod(action_space.shape[1:]))
|
||||
self._configured = True
|
||||
|
||||
if not isinstance(action_space, Box):
|
||||
logging.warning(f"Action space is discrete, got {type(action_space)}. CEMSolver may not work as expected.")
|
||||
|
||||
@property
|
||||
def n_envs(self) -> int:
|
||||
"""Number of parallel environments."""
|
||||
return self._n_envs
|
||||
|
||||
@property
|
||||
def action_dim(self) -> int:
|
||||
"""Flattened action dimension including action_block grouping."""
|
||||
return self._action_dim * self._config.action_block
|
||||
|
||||
@property
|
||||
def horizon(self) -> int:
|
||||
"""Planning horizon in timesteps."""
|
||||
return self._config.horizon
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> dict:
|
||||
"""Make solver callable, forwarding to solve()."""
|
||||
return self.solve(*args, **kwargs)
|
||||
|
||||
def init_action_distrib(
|
||||
self, actions: torch.Tensor | None = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Initialize the action distribution parameters (mean and variance)."""
|
||||
var = self.var_scale * torch.ones([self.n_envs, self.horizon, self.action_dim])
|
||||
mean = torch.zeros([self.n_envs, 0, self.action_dim]) if actions is None else actions
|
||||
|
||||
remaining = self.horizon - mean.shape[1]
|
||||
if remaining > 0:
|
||||
device = mean.device
|
||||
new_mean = torch.zeros([self.n_envs, remaining, self.action_dim])
|
||||
mean = torch.cat([mean, new_mean], dim=1).to(device)
|
||||
|
||||
return mean, var
|
||||
|
||||
@torch.inference_mode()
|
||||
def solve(
|
||||
self, info_dict: dict, init_action: torch.Tensor | None = None
|
||||
) -> dict:
|
||||
"""Solve the planning problem using Cross Entropy Method."""
|
||||
start_time = time.time()
|
||||
outputs = {
|
||||
"costs": [],
|
||||
"mean": [], # History of means
|
||||
"var": [], # History of vars
|
||||
}
|
||||
|
||||
# -- initialize the action distribution globally
|
||||
mean, var = self.init_action_distrib(init_action)
|
||||
mean = mean.to(self.device)
|
||||
var = var.to(self.device)
|
||||
|
||||
total_envs = self.n_envs
|
||||
|
||||
# --- Iterate over batches ---
|
||||
for start_idx in range(0, total_envs, self.batch_size):
|
||||
end_idx = min(start_idx + self.batch_size, total_envs)
|
||||
current_bs = end_idx - start_idx
|
||||
|
||||
# Slice Distribution Parameters for current batch
|
||||
batch_mean = mean[start_idx:end_idx]
|
||||
batch_var = var[start_idx:end_idx]
|
||||
|
||||
# Expand Info Dict for current batch
|
||||
expanded_infos = {}
|
||||
for k, v in info_dict.items():
|
||||
# v is shape (n_envs, ...)
|
||||
# Slice batch
|
||||
v_batch = v[start_idx:end_idx]
|
||||
if torch.is_tensor(v):
|
||||
if v_batch.device != self.device:
|
||||
v_batch = v_batch.to(self.device, non_blocking=True)
|
||||
# Add sample dim: (batch, 1, ...)
|
||||
v_batch = v_batch.unsqueeze(1)
|
||||
# Expand: (batch, num_samples, ...)
|
||||
v_batch = v_batch.expand(current_bs, self.num_samples, *v_batch.shape[2:])
|
||||
elif isinstance(v, np.ndarray):
|
||||
v_batch = np.repeat(v_batch[:, None, ...], self.num_samples, axis=1)
|
||||
expanded_infos[k] = v_batch
|
||||
|
||||
# Optimization Loop
|
||||
final_batch_cost = None
|
||||
|
||||
for step in range(self.n_steps):
|
||||
# Sample action sequences: (Batch, Num_Samples, Horizon, Dim)
|
||||
candidates = torch.randn(
|
||||
current_bs,
|
||||
self.num_samples,
|
||||
self.horizon,
|
||||
self.action_dim,
|
||||
generator=self.torch_gen,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# Scale and shift: (Batch, N, H, D) * (Batch, 1, H, D) + (Batch, 1, H, D)
|
||||
candidates = candidates * batch_var.unsqueeze(1) + batch_mean.unsqueeze(1)
|
||||
|
||||
# Force the first sample to be the current mean
|
||||
candidates[:, 0] = batch_mean
|
||||
|
||||
current_info = expanded_infos.copy()
|
||||
|
||||
# Evaluate candidates
|
||||
costs = self.model.get_cost(current_info, candidates)
|
||||
|
||||
assert isinstance(costs, torch.Tensor), f"Expected cost to be a torch.Tensor, got {type(costs)}"
|
||||
assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, (
|
||||
f"Expected cost to be of shape ({current_bs}, {self.num_samples}), got {costs.shape}"
|
||||
)
|
||||
|
||||
# Select Top-K
|
||||
# topk_vals: (Batch, K), topk_inds: (Batch, K)
|
||||
topk_vals, topk_inds = torch.topk(costs, k=self.topk, dim=1, largest=False)
|
||||
|
||||
# Gather Top-K Candidates
|
||||
# We need to select the specific candidates corresponding to topk_inds
|
||||
batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1).expand(-1, self.topk)
|
||||
|
||||
# Indexing: candidates[batch_idx, sample_idx]
|
||||
# Result shape: (Batch, K, Horizon, Dim)
|
||||
topk_candidates = candidates[batch_indices, topk_inds]
|
||||
|
||||
# Update Mean and Variance based on Top-K
|
||||
batch_mean = topk_candidates.mean(dim=1)
|
||||
batch_var = topk_candidates.std(dim=1)
|
||||
|
||||
# Update final cost for logging
|
||||
# We average the cost of the top elites
|
||||
final_batch_cost = topk_vals.mean(dim=1).cpu().tolist()
|
||||
|
||||
# Write results back to global storage
|
||||
mean[start_idx:end_idx] = batch_mean
|
||||
var[start_idx:end_idx] = batch_var
|
||||
|
||||
# Store history/metadata
|
||||
outputs["costs"].extend(final_batch_cost)
|
||||
|
||||
outputs["actions"] = mean.detach().cpu()
|
||||
outputs["mean"] = [mean.detach().cpu()]
|
||||
outputs["var"] = [var.detach().cpu()]
|
||||
|
||||
print(f"CEM solve time: {time.time() - start_time:.4f} seconds")
|
||||
return outputs
|
||||
@@ -55,3 +55,174 @@ metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True,
|
||||
True, True, True, True, True]), 'seeds': None}
|
||||
evaluation_time: 133.1857841014862 seconds
|
||||
inference_precision: fp32
|
||||
|
||||
==== CONFIG ====
|
||||
cache_dir: null
|
||||
solver:
|
||||
_target_: stable_worldmodel.solver.CEMSolver
|
||||
model: ???
|
||||
batch_size: 1
|
||||
num_samples: 300
|
||||
var_scale: 1.0
|
||||
n_steps: 30
|
||||
topk: 30
|
||||
device: cuda
|
||||
seed: ${seed}
|
||||
world:
|
||||
env_name: swm/TwoRoom-v1
|
||||
num_envs: ${eval.num_eval}
|
||||
max_episode_steps: 100
|
||||
history_size: 1
|
||||
frame_skip: 1
|
||||
seed: 42
|
||||
policy: two-room/tworoom/lejepa
|
||||
dataset:
|
||||
stats: ${eval.dataset_name}
|
||||
keys_to_cache:
|
||||
- action
|
||||
- proprio
|
||||
plan_config:
|
||||
horizon: 5
|
||||
receding_horizon: 5
|
||||
action_block: 5
|
||||
eval:
|
||||
num_eval: 50
|
||||
goal_offset_steps: 25
|
||||
eval_budget: 50
|
||||
img_size: 224
|
||||
dataset_name: tworoom
|
||||
callables:
|
||||
- method: _set_state
|
||||
args:
|
||||
state:
|
||||
value: proprio
|
||||
- method: _set_goal_state
|
||||
args:
|
||||
goal_state:
|
||||
value: goal_proprio
|
||||
output:
|
||||
filename: tworoom_results.txt
|
||||
|
||||
==== RESULTS ====
|
||||
metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
|
||||
True, True, True, True, True, True, True, True, True,
|
||||
True, True, True, False, True, True, True, True, True,
|
||||
True, True, True, True, False, True, True, True, True,
|
||||
True, True, False, True, True, True, True, True, True,
|
||||
True, True, True, True, True]), 'seeds': None}
|
||||
evaluation_time: 131.6325900554657 seconds
|
||||
inference_precision: fp32
|
||||
|
||||
==== CONFIG ====
|
||||
cache_dir: null
|
||||
solver:
|
||||
_target_: stable_worldmodel.solver.CEMSolver
|
||||
model: ???
|
||||
batch_size: 1
|
||||
num_samples: 300
|
||||
var_scale: 1.0
|
||||
n_steps: 30
|
||||
topk: 30
|
||||
device: cuda
|
||||
seed: ${seed}
|
||||
world:
|
||||
env_name: swm/TwoRoom-v1
|
||||
num_envs: ${eval.num_eval}
|
||||
max_episode_steps: 100
|
||||
history_size: 1
|
||||
frame_skip: 1
|
||||
seed: 42
|
||||
policy: two-room/tworoom/lejepa
|
||||
dataset:
|
||||
stats: ${eval.dataset_name}
|
||||
keys_to_cache:
|
||||
- action
|
||||
- proprio
|
||||
plan_config:
|
||||
horizon: 5
|
||||
receding_horizon: 5
|
||||
action_block: 5
|
||||
eval:
|
||||
num_eval: 50
|
||||
goal_offset_steps: 25
|
||||
eval_budget: 50
|
||||
img_size: 224
|
||||
dataset_name: tworoom
|
||||
callables:
|
||||
- method: _set_state
|
||||
args:
|
||||
state:
|
||||
value: proprio
|
||||
- method: _set_goal_state
|
||||
args:
|
||||
goal_state:
|
||||
value: goal_proprio
|
||||
output:
|
||||
filename: tworoom_results.txt
|
||||
|
||||
==== RESULTS ====
|
||||
metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
|
||||
True, True, True, True, True, True, True, True, True,
|
||||
True, True, True, False, True, True, True, True, True,
|
||||
True, True, True, True, False, True, True, True, True,
|
||||
True, True, False, True, True, True, True, True, True,
|
||||
True, True, True, True, True]), 'seeds': None}
|
||||
evaluation_time: 119.98270344734192 seconds
|
||||
inference_precision: fp32
|
||||
|
||||
==== CONFIG ====
|
||||
cache_dir: null
|
||||
solver:
|
||||
_target_: stable_worldmodel.solver.CEMSolver
|
||||
model: ???
|
||||
batch_size: 1
|
||||
num_samples: 300
|
||||
var_scale: 1.0
|
||||
n_steps: 30
|
||||
topk: 30
|
||||
device: cuda
|
||||
seed: ${seed}
|
||||
world:
|
||||
env_name: swm/TwoRoom-v1
|
||||
num_envs: ${eval.num_eval}
|
||||
max_episode_steps: 100
|
||||
history_size: 1
|
||||
frame_skip: 1
|
||||
seed: 42
|
||||
policy: two-room/tworoom/lejepa
|
||||
dataset:
|
||||
stats: ${eval.dataset_name}
|
||||
keys_to_cache:
|
||||
- action
|
||||
- proprio
|
||||
plan_config:
|
||||
horizon: 5
|
||||
receding_horizon: 5
|
||||
action_block: 5
|
||||
eval:
|
||||
num_eval: 50
|
||||
goal_offset_steps: 25
|
||||
eval_budget: 50
|
||||
img_size: 224
|
||||
dataset_name: tworoom
|
||||
callables:
|
||||
- method: _set_state
|
||||
args:
|
||||
state:
|
||||
value: proprio
|
||||
- method: _set_goal_state
|
||||
args:
|
||||
goal_state:
|
||||
value: goal_proprio
|
||||
output:
|
||||
filename: tworoom_results.txt
|
||||
|
||||
==== RESULTS ====
|
||||
metrics: {'success_rate': 88.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
|
||||
True, True, True, True, True, True, True, True, True,
|
||||
True, True, True, False, True, True, True, True, True,
|
||||
True, True, True, True, False, True, True, True, True,
|
||||
True, True, False, True, True, True, True, True, True,
|
||||
True, True, True, True, True]), 'seeds': None}
|
||||
evaluation_time: 121.47896695137024 seconds
|
||||
inference_precision: fp32
|
||||
|
||||
Reference in New Issue
Block a user