Compare commits
32 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a639fdefca | ||
|
|
28f2fba0e8 | ||
|
|
113e591899 | ||
|
|
0164e21f48 | ||
|
|
02080e2564 | ||
|
|
d86aeb2df0 | ||
|
|
5e55727901 | ||
|
|
02c3cea3f9 | ||
|
|
f08f2b82f4 | ||
|
|
e84074d6d6 | ||
|
|
cf43af0729 | ||
|
|
4c3fdbcce6 | ||
|
|
75a5d86966 | ||
|
|
46cb2177bc | ||
|
|
8ba5bc8b0b | ||
|
|
e6f2b2b9d4 | ||
|
|
25e4ddb628 | ||
|
|
995cd8cfec | ||
|
|
cd03a0d5cb | ||
|
|
20ffb3492b | ||
|
|
96e17a13af | ||
|
|
006102d00c | ||
|
|
3a94829eac | ||
|
|
38be7d3bef | ||
|
|
f2750daace | ||
|
|
9e2407cdc4 | ||
|
|
0f85e39690 | ||
|
|
85795bd91d | ||
|
|
7c2e341d93 | ||
|
|
12ba4f4352 | ||
|
|
fa1c15c896 | ||
|
|
8b84251eb9 |
34
.gitignore
vendored
Normal file
34
.gitignore
vendored
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
.venv/*
|
||||||
|
!.venv/.gitignore
|
||||||
|
!.venv/lib/
|
||||||
|
.venv/lib/*
|
||||||
|
!.venv/lib/python3.10/
|
||||||
|
.venv/lib/python3.10/*
|
||||||
|
!.venv/lib/python3.10/site-packages/
|
||||||
|
.venv/lib/python3.10/site-packages/*
|
||||||
|
!.venv/lib/python3.10/site-packages/stable_worldmodel/
|
||||||
|
.venv/lib/python3.10/site-packages/stable_worldmodel/*
|
||||||
|
!.venv/lib/python3.10/site-packages/stable_worldmodel/solver/
|
||||||
|
!.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py
|
||||||
|
!.venv/lib/python3.10/site-packages/stable_worldmodel/world.py
|
||||||
|
!.venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py
|
||||||
|
!.venv/lib/python3.10/site-packages/stable_worldmodel/solver/gd.py
|
||||||
|
outputs/
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
|
||||||
|
.pytest_cache/
|
||||||
|
.mypy_cache/
|
||||||
|
.ruff_cache/
|
||||||
|
|
||||||
|
torch_profile/
|
||||||
|
trace.json
|
||||||
|
key_averages.txt
|
||||||
|
eval_tmp_*.npy
|
||||||
|
*.mp4
|
||||||
|
*.gif
|
||||||
|
|
||||||
|
.DS_Store
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.log
|
||||||
0
.venv/.gitignore
vendored
Normal file
0
.venv/.gitignore
vendored
Normal file
628
.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py
Normal file
628
.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py
Normal file
@@ -0,0 +1,628 @@
|
|||||||
|
from collections import deque
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import inspect
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Protocol
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from loguru import logger as logging
|
||||||
|
|
||||||
|
import stable_worldmodel as swm
|
||||||
|
from stable_worldmodel.solver import Solver
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class PlanConfig:
    """Immutable settings that drive the MPC planning loop.

    Attributes:
        horizon: How many planning steps the solver looks ahead.
        receding_horizon: How many planned steps are executed before the
            policy re-plans.
        history_len: How many past observations are taken into account.
        action_block: How many times each planned action is repeated
            (frameskip).
        warm_start: When True, the previous plan is used to initialize the
            next one.
    """

    horizon: int
    receding_horizon: int
    history_len: int = 1
    action_block: int = 1
    warm_start: bool = True

    @property
    def plan_len(self) -> int:
        """Return the full plan length, counted in environment steps."""
        steps, repeat = self.horizon, self.action_block
        return steps * repeat
|
||||||
|
|
||||||
|
|
||||||
|
class Transformable(Protocol):
    """Structural type for invertible preprocessors (normalizers, scalers, ...)."""

    def transform(self, x: np.ndarray) -> np.ndarray:  # pragma: no cover
        """Run the forward preprocessing step.

        Args:
            x: Raw data as a numpy array.

        Returns:
            The preprocessed data as a numpy array.
        """
        ...

    def inverse_transform(self, x: np.ndarray) -> np.ndarray:  # pragma: no cover
        """Undo the preprocessing step.

        Args:
            x: Data that was previously preprocessed, as a numpy array.

        Returns:
            The data mapped back to its original form, as a numpy array.
        """
        ...
|
||||||
|
|
||||||
|
|
||||||
|
class Actionable(Protocol):
    """Protocol for models that can compute an action."""

    def get_action(self, info: dict) -> torch.Tensor:  # pragma: no cover
        """Compute an action from the prepared info dictionary.

        The method is declared with an explicit ``self`` so that the protocol
        matches how it is actually invoked on a model instance
        (``model.get_action(info_dict)``).

        Args:
            info: Dictionary of inputs for the model (observation, goal, ...).

        Returns:
            The action as a tensor.
        """
        ...
|
||||||
|
|
||||||
|
|
||||||
|
class BasePolicy:
|
||||||
|
"""Base class for agent policies.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
env: The environment the policy is associated with.
|
||||||
|
type: A string identifier for the policy type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
env: Any
|
||||||
|
type: str
|
||||||
|
|
||||||
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
|
"""Initialize the base policy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Additional configuration parameters.
|
||||||
|
"""
|
||||||
|
self.env = None
|
||||||
|
self.type = 'base'
|
||||||
|
for arg, value in kwargs.items():
|
||||||
|
setattr(self, arg, value)
|
||||||
|
|
||||||
|
def get_action(self, obs: Any, **kwargs: Any) -> np.ndarray:
|
||||||
|
"""Get action from the policy given the observation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obs: The current observation from the environment.
|
||||||
|
**kwargs: Additional parameters for action selection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Selected action as a numpy array.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotImplementedError: If not implemented by a subclass.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def set_env(self, env: Any) -> None:
|
||||||
|
"""Associate this policy with an environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The environment to associate.
|
||||||
|
"""
|
||||||
|
self.env = env
|
||||||
|
|
||||||
|
def _move_info_to_device(
|
||||||
|
self, info_dict: dict[str, Any], device: torch.device | str
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
target = torch.device(device)
|
||||||
|
for k, v in info_dict.items():
|
||||||
|
if torch.is_tensor(v):
|
||||||
|
if v.device != target:
|
||||||
|
v = v.to(target, non_blocking=True)
|
||||||
|
if not v.is_contiguous():
|
||||||
|
v = v.contiguous()
|
||||||
|
info_dict[k] = v
|
||||||
|
return info_dict
|
||||||
|
|
||||||
|
def _prepare_info_for_device(
|
||||||
|
self,
|
||||||
|
info_dict: dict,
|
||||||
|
device: torch.device | str | None = None,
|
||||||
|
) -> dict[str, torch.Tensor]:
|
||||||
|
if device is None:
|
||||||
|
return self._prepare_info(info_dict)
|
||||||
|
|
||||||
|
prepare_info = self._prepare_info
|
||||||
|
try:
|
||||||
|
signature = inspect.signature(prepare_info)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return prepare_info(info_dict)
|
||||||
|
|
||||||
|
accepts_device = any(
|
||||||
|
param.kind == inspect.Parameter.VAR_KEYWORD
|
||||||
|
or (param.name == "device" and param.kind in {
|
||||||
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||||
|
inspect.Parameter.KEYWORD_ONLY,
|
||||||
|
})
|
||||||
|
for param in signature.parameters.values()
|
||||||
|
)
|
||||||
|
if accepts_device:
|
||||||
|
return prepare_info(info_dict, device=device)
|
||||||
|
return prepare_info(info_dict)
|
||||||
|
|
||||||
|
def _prepare_info(
|
||||||
|
self,
|
||||||
|
info_dict: dict,
|
||||||
|
device: torch.device | str | None = None,
|
||||||
|
) -> dict[str, torch.Tensor]:
|
||||||
|
"""Pre-process and transform observations.
|
||||||
|
|
||||||
|
Applies preprocessing (via `self.process`) and transformations (via `self.transform`)
|
||||||
|
to observation data. Used by subclasses like FeedForwardPolicy and WorldModelPolicy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
info_dict: Raw observation dictionary from the environment.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of processed tensors.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If an expected numpy array is missing for processing.
|
||||||
|
"""
|
||||||
|
target_device = torch.device(device) if device is not None else None
|
||||||
|
for k, v in info_dict.items():
|
||||||
|
is_numpy = isinstance(v, (np.ndarray | np.generic))
|
||||||
|
|
||||||
|
if hasattr(self, 'process') and k in self.process:
|
||||||
|
if not is_numpy:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected numpy array for key '{k}' in process, got {type(v)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# flatten extra dimensions if needed
|
||||||
|
shape = v.shape
|
||||||
|
if len(shape) > 2:
|
||||||
|
v = v.reshape(-1, *shape[2:])
|
||||||
|
|
||||||
|
# process and reshape back
|
||||||
|
v = self.process[k].transform(v)
|
||||||
|
v = v.reshape(shape)
|
||||||
|
|
||||||
|
# collapse env and time dimensions for transform (e, t, ...) -> (e * t, ...)
|
||||||
|
# then restore after transform
|
||||||
|
if hasattr(self, 'transform') and k in self.transform:
|
||||||
|
shape = None
|
||||||
|
if is_numpy or torch.is_tensor(v):
|
||||||
|
if v.ndim > 2:
|
||||||
|
shape = v.shape
|
||||||
|
v = v.reshape(-1, *shape[2:])
|
||||||
|
if is_numpy:
|
||||||
|
if v.dtype.kind in 'USO':
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected numeric numpy array for key '{k}', got dtype {v.dtype}"
|
||||||
|
)
|
||||||
|
v = torch.from_numpy(v)
|
||||||
|
is_numpy = False
|
||||||
|
moved_for_transform = False
|
||||||
|
if target_device is not None and target_device.type != "cpu":
|
||||||
|
if v.device != target_device:
|
||||||
|
v = v.to(target_device, non_blocking=True)
|
||||||
|
moved_for_transform = True
|
||||||
|
if k.startswith('pixels') or k.startswith('goal'):
|
||||||
|
# Vectorized image transform on the full batch.
|
||||||
|
v = v.permute(0, 3, 1, 2)
|
||||||
|
try:
|
||||||
|
v = self.transform[k](v)
|
||||||
|
except (NotImplementedError, RuntimeError):
|
||||||
|
if not moved_for_transform:
|
||||||
|
raise
|
||||||
|
v = self.transform[k](v.cpu())
|
||||||
|
|
||||||
|
if shape is not None:
|
||||||
|
v = v.reshape(*shape[:2], *v.shape[1:])
|
||||||
|
|
||||||
|
if is_numpy and v.dtype.kind not in 'USO':
|
||||||
|
v = torch.from_numpy(v)
|
||||||
|
|
||||||
|
if (
|
||||||
|
torch.cuda.is_available()
|
||||||
|
and torch.is_tensor(v)
|
||||||
|
and v.device.type == "cpu"
|
||||||
|
and not v.is_pinned()
|
||||||
|
):
|
||||||
|
v = v.pin_memory()
|
||||||
|
|
||||||
|
info_dict[k] = v
|
||||||
|
|
||||||
|
return info_dict
|
||||||
|
|
||||||
|
|
||||||
|
class RandomPolicy(BasePolicy):
    """Policy whose actions are drawn at random from the env action space."""

    def __init__(self, seed: int | None = None, **kwargs: Any) -> None:
        """Create the random policy.

        Args:
            seed: Optional random seed for the action space; stored on the
                instance as ``seed``.
            **kwargs: Extra configuration forwarded to ``BasePolicy``.
        """
        super().__init__(**kwargs)
        self.seed = seed
        self.type = 'random'

    def get_action(self, obs: Any, **kwargs: Any) -> np.ndarray:
        """Sample an action from the environment's action space.

        Args:
            obs: Current observation; not used.
            **kwargs: Extra parameters; not used.

        Returns:
            The sampled action.
        """
        action_space = self.env.action_space
        return action_space.sample()

    def set_seed(self, seed: int) -> None:
        """Seed the action space used for sampling.

        Does nothing when no environment has been attached yet.

        Args:
            seed: The seed value.
        """
        if self.env is None:
            return
        self.env.action_space.seed(seed)
|
||||||
|
|
||||||
|
|
||||||
|
class ExpertPolicy(BasePolicy):
    """Policy backed by expert demonstrations or heuristics."""

    def __init__(self, **kwargs: Any) -> None:
        """Create the expert policy.

        Args:
            **kwargs: Extra configuration forwarded to ``BasePolicy``.
        """
        super().__init__(**kwargs)
        self.type = 'expert'

    def get_action(
        self, obs: Any, goal_obs: Any, **kwargs: Any
    ) -> np.ndarray | None:
        """Return the expert action for the given observation and goal.

        This is a placeholder: no expert logic is implemented here, so the
        method always returns None.

        Args:
            obs: The current observation.
            goal_obs: The goal observation.
            **kwargs: Additional parameters.

        Returns:
            The expert action, or None if not available.
        """
        # Implement expert policy logic here
        return None
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForwardPolicy(BasePolicy):
    """Policy that gets its action from one forward pass of a neural network.

    Useful for imitation learning policies like Goal-Conditioned Behavioral
    Cloning (GCBC).

    Attributes:
        model: Neural network model implementing the Actionable protocol.
        process: Per-key data preprocessors.
        transform: Per-key tensor transformations (e.g., image transforms).
    """

    def __init__(
        self,
        model: Actionable,
        process: dict[str, Transformable] | None = None,
        transform: dict[str, Callable[[torch.Tensor], torch.Tensor]]
        | None = None,
        **kwargs: Any,
    ) -> None:
        """Create the feed-forward policy.

        Args:
            model: Neural network model with a `get_action` method; it is put
                into eval mode.
            process: Per-key data preprocessors; defaults to an empty dict.
            transform: Per-key tensor transformations; defaults to an empty
                dict.
            **kwargs: Extra configuration forwarded to ``BasePolicy``.
        """
        super().__init__(**kwargs)
        self.type = 'feed_forward'
        self.model = model.eval()
        self.process = process if process else {}
        self.transform = transform if transform else {}

    def get_action(self, info_dict: dict, **kwargs: Any) -> np.ndarray:
        """Run the model on the prepared inputs and return its action.

        Args:
            info_dict: Current state information; must contain a 'goal' key.
            **kwargs: Additional parameters (unused).

        Returns:
            The selected action as a numpy array.

        Raises:
            AssertionError: If the policy has no ``env`` attribute or 'goal'
                is missing from info_dict.
        """
        assert hasattr(self, 'env'), 'Environment not set for the policy'
        assert 'goal' in info_dict, "'goal' must be provided in info_dict"

        # The model's device is both the preparation hint and the final target.
        model_device = next(self.model.parameters()).device

        # Transform and normalize the inputs.
        prepared = self._prepare_info_for_device(info_dict, device=model_device)

        # The GCBC model reads the goal under the 'goal_pixels' key.
        if 'goal' in prepared:
            prepared['goal_pixels'] = prepared['goal']

        prepared = self._move_info_to_device(prepared, model_device)

        with torch.no_grad():
            action = self.model.get_action(prepared)

        if torch.is_tensor(action):
            action = action.cpu().detach().numpy()

        # Map the action back through the 'action' preprocessor, if any.
        if 'action' in self.process:
            action = self.process['action'].inverse_transform(action)

        return action
|
||||||
|
|
||||||
|
|
||||||
|
class WorldModelPolicy(BasePolicy):
    """Policy using a world model and planning solver for action selection."""

    def __init__(
        self,
        solver: Solver,
        config: PlanConfig,
        process: dict[str, Transformable] | None = None,
        transform: dict[str, Callable[[torch.Tensor], torch.Tensor]]
        | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the world model policy.

        Args:
            solver: The planning solver to use.
            config: MPC planning configuration.
            process: Dictionary of data preprocessors for specific keys.
            transform: Dictionary of tensor transformations (e.g., image transforms).
            **kwargs: Additional configuration parameters.
        """
        super().__init__(**kwargs)

        self.type = 'world_model'
        self.cfg = config
        self.solver = solver
        # Public attribute kept for backward compatibility; the planning loop
        # in get_action() works on ``_action_buffer``, created by set_env().
        self.action_buffer: deque[torch.Tensor] = deque(
            maxlen=self.flatten_receding_horizon
        )
        self.process = process or {}
        self.transform = transform or {}
        self._action_buffer: deque[torch.Tensor] | None = None
        self._next_init: torch.Tensor | None = None

    @property
    def flatten_receding_horizon(self) -> int:
        """Receding horizon in environment steps (with frameskip)."""
        return self.cfg.receding_horizon * self.cfg.action_block

    def _num_envs(self) -> int:
        """Return the env's ``num_envs``, or 1 if it has none (as in set_env)."""
        return getattr(self.env, 'num_envs', 1)

    def set_env(self, env: Any) -> None:
        """Configure the policy and solver for the given environment.

        The solver is validated before it is configured, and any plan state
        left over from a previous environment (buffered actions and the
        warm-start initialization) is discarded.

        Args:
            env: The environment to associate with the policy.

        Raises:
            AssertionError: If the solver does not implement the Solver protocol.
        """
        # Validate first so an invalid solver is reported before it is used.
        assert isinstance(self.solver, Solver), (
            'Solver must implement the Solver protocol'
        )

        self.env = env
        self.solver.configure(
            action_space=env.action_space,
            n_envs=self._num_envs(),
            config=self.cfg,
        )
        self._action_buffer = deque(maxlen=self.flatten_receding_horizon)
        # The buffered plan is dropped above; drop the warm start with it so a
        # plan computed for another environment is never reused.
        self._next_init = None

    def _normalize_active_mask(self, active_mask: Any) -> np.ndarray | None:
        """Convert ``active_mask`` to a 1-D boolean array, or pass None through.

        Args:
            active_mask: None, or anything ``np.asarray`` can turn into a
                boolean array with one entry per environment.

        Returns:
            The mask as a boolean numpy array, or None.

        Raises:
            ValueError: If the mask is not 1-D with one entry per environment.
        """
        if active_mask is None:
            return None
        n_envs = self._num_envs()
        active_mask = np.asarray(active_mask, dtype=bool)
        if active_mask.ndim != 1 or active_mask.shape[0] != n_envs:
            raise ValueError(
                f"active_mask must have shape ({n_envs},), got {active_mask.shape}"
            )
        return active_mask

    def get_action(self, info_dict: dict, **kwargs: Any) -> np.ndarray:
        """Get action via planning with the world model.

        A new plan is computed only when the internal action buffer is empty;
        otherwise the next buffered action is returned.

        Args:
            info_dict: Current state information from the environment. Must
                contain 'pixels' and 'goal' whenever a re-plan is needed.
            **kwargs: Additional parameters for planning. ``active_mask``
                (optional) marks the environments to plan for; actions of the
                other environments are set to zero.

        Returns:
            The selected action(s) as a numpy array.

        Raises:
            AssertionError: If set_env() has not been called, or a re-plan is
                needed and 'pixels' or 'goal' is missing.
            ValueError: If ``active_mask`` has the wrong shape.
        """
        # ``env`` always exists as an attribute (BasePolicy sets it to None),
        # so check the values that set_env() actually initializes.
        assert self.env is not None and self._action_buffer is not None, (
            'Environment not set for the policy'
        )
        n_envs = self._num_envs()
        active_mask = self._normalize_active_mask(kwargs.get("active_mask"))

        # need to replan if action buffer is empty
        if len(self._action_buffer) == 0:
            assert 'pixels' in info_dict, "'pixels' must be provided in info_dict"
            assert 'goal' in info_dict, "'goal' must be provided in info_dict"

            info_dict = self._prepare_info_for_device(
                info_dict, device=self.solver.device
            )
            info_dict = self._move_info_to_device(info_dict, self.solver.device)

            outputs = self.solver(
                info_dict,
                init_action=self._next_init,
                active_mask=active_mask,
            )

            actions = outputs['actions']  # (num_envs, horizon, action_dim)
            if active_mask is not None and actions.shape[0] != n_envs:
                # The solver returned rows for the active envs only: scatter
                # them into a zero-filled tensor covering every env.
                full_actions = torch.zeros(
                    n_envs,
                    actions.shape[1],
                    actions.shape[2],
                    dtype=actions.dtype,
                    device=actions.device,
                )
                full_actions[torch.as_tensor(active_mask, device=actions.device)] = actions
                actions = full_actions

            keep_horizon = self.cfg.receding_horizon
            plan = actions[:, :keep_horizon]
            rest = actions[:, keep_horizon:]
            # The unexecuted tail of the plan seeds the next solve.
            self._next_init = rest.contiguous() if self.cfg.warm_start else None

            # frameskip back to timestep
            plan = plan.reshape(
                n_envs, self.flatten_receding_horizon, -1
            ).contiguous()

            # One buffer entry per environment step.
            self._action_buffer.extend(plan.transpose(0, 1).unbind(0))

        action = self._action_buffer.popleft()
        action = action.reshape(*self.env.action_space.shape)
        if active_mask is not None:
            if torch.is_tensor(action):
                inactive_mask = torch.as_tensor(
                    ~active_mask, device=action.device, dtype=torch.bool
                )
                # Clone so the zeroing does not write into the stored plan.
                action = action.clone()
                action[inactive_mask] = 0
            else:
                action = np.array(action, copy=True)
                action[~active_mask] = 0
        if torch.is_tensor(action):
            action = action.detach().cpu().numpy()
        else:
            action = np.asarray(action)

        # post-process action
        if 'action' in self.process:
            action = self.process['action'].inverse_transform(action)

        return action  # (num_envs, action_dim)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_model_with_attribute(run_name, attribute_name, cache_dir=None):
    """Load a model checkpoint and find a module with the specified attribute.

    ``run_name`` is used as a path if it exists on disk; otherwise it is
    resolved relative to ``cache_dir`` (or the default cache directory). If
    the resolved path is a directory, the ``*_object.ckpt`` file with the
    largest ``st_ctime`` inside it is loaded; otherwise
    ``<run_path>_object.ckpt`` is loaded.

    Args:
        run_name: Path or name of the model run
        attribute_name: Name of the attribute to look for in the module (e.g., 'get_action', 'get_cost')
        cache_dir: Optional cache directory path

    Returns:
        The first object found (the loaded object itself, then its children
        depth-first) that has the attribute; ``torch.nn.Module`` instances
        are put into eval mode.

    Raises:
        FileNotFoundError: If the run directory contains no ``*_object.ckpt`` file
        AssertionError: If the single-file checkpoint path does not exist
        RuntimeError: If no module with the specified attribute is found
    """
    if Path(run_name).exists():
        run_path = Path(run_name)
    else:
        run_path = Path(cache_dir or swm.data.utils.get_cache_dir(), run_name)

    if run_path.is_dir():
        # Pick the checkpoint with the largest ctime; fail with a clear
        # message rather than an IndexError when the directory has none.
        path = max(
            run_path.glob('*_object.ckpt'),
            key=lambda x: x.stat().st_ctime,
            default=None,
        )
        if path is None:
            raise FileNotFoundError(
                f"No '*_object.ckpt' checkpoint found in directory: {run_path}"
            )
        logging.info(f'Loading model from checkpoint: {path}')
    else:
        path = Path(f'{run_path}_object.ckpt')
        assert path.exists(), (
            f'Checkpoint path does not exist: {path}. Launch pretraining first.'
        )

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints from trusted sources.
    spt_module = torch.load(path, weights_only=False, map_location='cpu')

    def scan_module(module):
        """Depth-first search for the first object exposing the attribute."""
        if hasattr(module, attribute_name):
            if isinstance(module, torch.nn.Module):
                module = module.eval()
            return module
        # The loaded object may not be an nn.Module (no .children()); treat
        # it as a leaf so the documented RuntimeError is raised below.
        children = getattr(module, 'children', None)
        if not callable(children):
            return None
        for child in children():
            result = scan_module(child)
            if result is not None:
                return result
        return None

    result = scan_module(spt_module)
    if result is not None:
        return result

    raise RuntimeError(
        f"No module with '{attribute_name}' found in the loaded world model."
    )
|
||||||
|
|
||||||
|
|
||||||
|
def AutoActionableModel(
    run_name: str, cache_dir: str | Path | None = None
) -> torch.nn.Module:
    """Load a checkpoint and return the module that exposes `get_action`.

    The checkpoint is scanned for a module implementing the Actionable
    protocol, i.e. one that has a `get_action` attribute.

    Args:
        run_name: Path or name of the model run/checkpoint.
        cache_dir: Optional cache directory path. Defaults to STABLEWM_HOME.

    Returns:
        The module with a `get_action` method, set to eval mode.

    Raises:
        RuntimeError: If no module with `get_action` is found in the checkpoint.
    """
    actionable = _load_model_with_attribute(
        run_name, attribute_name='get_action', cache_dir=cache_dir
    )
    return actionable
|
||||||
|
|
||||||
|
|
||||||
|
def AutoCostModel(
    run_name: str, cache_dir: str | Path | None = None
) -> torch.nn.Module:
    """Load a checkpoint and return the module that exposes `get_cost`.

    The checkpoint is scanned for a module implementing a cost function, i.e.
    one that has a `get_cost` attribute, for use with planning solvers.

    Args:
        run_name: Path or name of the model run/checkpoint.
        cache_dir: Optional cache directory path. Defaults to STABLEWM_HOME.

    Returns:
        The module with a `get_cost` method, set to eval mode.

    Raises:
        RuntimeError: If no module with `get_cost` is found in the checkpoint.
    """
    cost_model = _load_model_with_attribute(
        run_name, attribute_name='get_cost', cache_dir=cache_dir
    )
    return cost_model
|
||||||
|
|
||||||
|
|
||||||
|
# Alias for backward compatibility and type hinting
|
||||||
|
Policy = BasePolicy
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
from .cem import CEMSolver
|
||||||
|
from .gd import GradientSolver
|
||||||
|
from .icem import ICEMSolver
|
||||||
|
from .lagrangian import LagrangianSolver
|
||||||
|
from .mppi import MPPISolver
|
||||||
|
from .solver import Solver
|
||||||
|
from .discrete_solvers import PGDSolver
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'Solver',
|
||||||
|
'GradientSolver',
|
||||||
|
'CEMSolver',
|
||||||
|
'ICEMSolver',
|
||||||
|
'PGDSolver',
|
||||||
|
'MPPISolver',
|
||||||
|
'LagrangianSolver',
|
||||||
|
]
|
||||||
@@ -0,0 +1,277 @@
|
|||||||
|
"""Cross Entropy Method solver for model-based planning."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from gymnasium.spaces import Box
|
||||||
|
from loguru import logger as logging
|
||||||
|
|
||||||
|
from .solver import Costable
|
||||||
|
|
||||||
|
|
||||||
|
class CEMSolver:
|
||||||
|
"""Cross Entropy Method solver for action optimization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: World model implementing the Costable protocol.
|
||||||
|
batch_size: Number of environments to process in parallel.
|
||||||
|
num_samples: Number of action candidates to sample per iteration.
|
||||||
|
var_scale: Initial variance scale for the action distribution.
|
||||||
|
n_steps: Number of CEM iterations.
|
||||||
|
topk: Number of elite samples to keep for distribution update.
|
||||||
|
device: Device for tensor computations.
|
||||||
|
seed: Random seed for reproducibility.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
    self,
    model: Costable,
    batch_size: int = 1,
    num_samples: int = 300,
    var_scale: float = 1,
    n_steps: int = 30,
    topk: int = 30,
    device: str | torch.device = "cpu",
    seed: int = 1234,
) -> None:
    """Store the solver hyper-parameters and create a seeded generator.

    Args:
        model: World model implementing the Costable protocol.
        batch_size: Number of environments to process in parallel.
        num_samples: Number of action candidates to sample per iteration.
        var_scale: Initial variance scale for the action distribution.
        n_steps: Number of CEM iterations.
        topk: Number of elite samples to keep for distribution update.
        device: Device for tensor computations.
        seed: Random seed for reproducibility.
    """
    self.model = model
    self.device = device
    self.batch_size = batch_size
    self.num_samples = num_samples
    self.var_scale = var_scale
    self.n_steps = n_steps
    self.topk = topk

    # Dedicated generator on the solver's device, seeded for reproducibility.
    generator = torch.Generator(device=device)
    self.torch_gen = generator.manual_seed(seed)
|
||||||
|
|
||||||
|
def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
    """Record the environment specification the solver will plan for.

    Stores the action space, number of environments and planning config, and
    derives the per-step action size as the product of
    ``action_space.shape[1:]``. A warning is logged when the action space is
    not a ``Box``.

    Args:
        action_space: Action space of the environment.
        n_envs: Number of parallel environments.
        config: Planning configuration object.
    """
    self._action_space = action_space
    self._n_envs = n_envs
    self._config = config

    per_env_shape = action_space.shape[1:]
    self._action_dim = int(np.prod(per_env_shape))
    self._configured = True

    if isinstance(action_space, Box):
        return
    logging.warning(f"Action space is discrete, got {type(action_space)}. CEMSolver may not work as expected.")
|
||||||
|
|
||||||
|
@property
def n_envs(self) -> int:
    """Return the number of parallel environments set by ``configure``."""
    count = self._n_envs
    return count
|
||||||
|
|
||||||
|
@property
def action_dim(self) -> int:
    """Return the flattened action size, including action_block grouping."""
    per_step = self._action_dim
    block = self._config.action_block
    return per_step * block
|
||||||
|
|
||||||
|
@property
def horizon(self) -> int:
    """Return the planning horizon taken from the stored config."""
    config = self._config
    return config.horizon
|
||||||
|
|
||||||
|
def __call__(self, *args: Any, **kwargs: Any) -> dict:
    """Call ``solve`` with the given arguments and return its result."""
    result = self.solve(*args, **kwargs)
    return result
|
||||||
|
|
||||||
|
@staticmethod
def _normalize_active_mask(
    active_mask: torch.Tensor | np.ndarray | None,
    n_envs: int,
    device: torch.device,
) -> torch.Tensor | None:
    """Return ``active_mask`` as a 1-D boolean tensor on ``device``.

    Args:
        active_mask: None, a tensor, or anything ``torch.as_tensor`` accepts.
        n_envs: Expected number of entries in the mask.
        device: Device the returned mask is placed on.

    Returns:
        The boolean mask tensor, or None when ``active_mask`` is None.

    Raises:
        ValueError: If the mask is not 1-D with ``n_envs`` entries.
    """
    if active_mask is None:
        return None

    if torch.is_tensor(active_mask):
        mask = active_mask.to(device=device, dtype=torch.bool)
    else:
        mask = torch.as_tensor(active_mask, dtype=torch.bool, device=device)

    has_expected_shape = mask.ndim == 1 and mask.shape[0] == n_envs
    if not has_expected_shape:
        raise ValueError(
            f"active_mask must have shape ({n_envs},), got {tuple(mask.shape)}"
        )
    return mask
|
||||||
|
|
||||||
|
def init_action_distrib(
    self, actions: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the initial mean and variance of the action distribution.

    The variance is ``var_scale`` everywhere. The mean starts from
    ``actions`` when given (empty otherwise) and is padded with zeros along
    dimension 1 until it spans the planning horizon.

    Args:
        actions: Optional initial actions used as the leading part of the mean.

    Returns:
        A ``(mean, var)`` tuple of tensors.
    """
    device = torch.device(self.device)
    n_envs, horizon, action_dim = self.n_envs, self.horizon, self.action_dim

    var = self.var_scale * torch.ones(
        [n_envs, horizon, action_dim],
        device=device,
    )

    if actions is None:
        mean = torch.zeros([n_envs, 0, action_dim], device=device)
    else:
        mean = actions

    # Zero-pad the steps not covered by the provided actions.
    remaining = horizon - mean.shape[1]
    if remaining > 0:
        padding = torch.zeros(
            [n_envs, remaining, action_dim],
            device=mean.device,
        )
        mean = torch.cat([mean, padding], dim=1)

    return mean, var
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def solve(
|
||||||
|
self,
|
||||||
|
info_dict: dict,
|
||||||
|
init_action: torch.Tensor | None = None,
|
||||||
|
active_mask: torch.Tensor | np.ndarray | None = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Solve the planning problem using Cross Entropy Method."""
|
||||||
|
start_time = time.time()
|
||||||
|
outputs = {
|
||||||
|
"costs": [],
|
||||||
|
"mean": [], # History of means
|
||||||
|
"var": [], # History of vars
|
||||||
|
}
|
||||||
|
|
||||||
|
# -- initialize the action distribution globally
|
||||||
|
mean, var = self.init_action_distrib(init_action)
|
||||||
|
if mean.device != torch.device(self.device):
|
||||||
|
mean = mean.to(self.device, non_blocking=True)
|
||||||
|
if var.device != torch.device(self.device):
|
||||||
|
var = var.to(self.device, non_blocking=True)
|
||||||
|
active_mask = self._normalize_active_mask(
|
||||||
|
active_mask, self.n_envs, torch.device(self.device)
|
||||||
|
)
|
||||||
|
|
||||||
|
if active_mask is not None and not torch.any(active_mask):
|
||||||
|
return {
|
||||||
|
"costs": [],
|
||||||
|
"actions": mean.detach(),
|
||||||
|
"mean": [mean.detach()],
|
||||||
|
"var": [var.detach()],
|
||||||
|
}
|
||||||
|
|
||||||
|
total_envs = self.n_envs
|
||||||
|
|
||||||
|
# --- Iterate over batches ---
|
||||||
|
for start_idx in range(0, total_envs, self.batch_size):
|
||||||
|
end_idx = min(start_idx + self.batch_size, total_envs)
|
||||||
|
current_bs = end_idx - start_idx
|
||||||
|
|
||||||
|
# Slice Distribution Parameters for current batch
|
||||||
|
batch_mean = mean[start_idx:end_idx]
|
||||||
|
batch_var = var[start_idx:end_idx]
|
||||||
|
|
||||||
|
# Expand Info Dict for current batch
|
||||||
|
expanded_infos = {}
|
||||||
|
for k, v in info_dict.items():
|
||||||
|
# v is shape (n_envs, ...)
|
||||||
|
# Slice batch
|
||||||
|
v_batch = v[start_idx:end_idx]
|
||||||
|
if torch.is_tensor(v):
|
||||||
|
if v_batch.device != self.device:
|
||||||
|
v_batch = v_batch.to(self.device, non_blocking=True)
|
||||||
|
# Add sample dim: (batch, 1, ...)
|
||||||
|
v_batch = v_batch.unsqueeze(1)
|
||||||
|
# Expand: (batch, num_samples, ...)
|
||||||
|
v_batch = v_batch.expand(current_bs, self.num_samples, *v_batch.shape[2:])
|
||||||
|
elif isinstance(v, np.ndarray):
|
||||||
|
v_batch = np.repeat(v_batch[:, None, ...], self.num_samples, axis=1)
|
||||||
|
expanded_infos[k] = v_batch
|
||||||
|
|
||||||
|
if active_mask is not None:
|
||||||
|
batch_mask = active_mask[start_idx:end_idx]
|
||||||
|
if not torch.any(batch_mask):
|
||||||
|
outputs["costs"].append(
|
||||||
|
torch.full((current_bs,), float("nan"), device=self.device)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
active_local = torch.nonzero(batch_mask, as_tuple=False).squeeze(1)
|
||||||
|
active_local_np = active_local.detach().cpu().numpy()
|
||||||
|
batch_mean = batch_mean[active_local]
|
||||||
|
batch_var = batch_var[active_local]
|
||||||
|
expanded_infos = {
|
||||||
|
k: (v[active_local] if torch.is_tensor(v) else v[active_local_np])
|
||||||
|
for k, v in expanded_infos.items()
|
||||||
|
}
|
||||||
|
current_bs = int(active_local.numel())
|
||||||
|
else:
|
||||||
|
active_local = None
|
||||||
|
|
||||||
|
# Optimization Loop
|
||||||
|
final_batch_cost = None
|
||||||
|
batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1)
|
||||||
|
|
||||||
|
for step in range(self.n_steps):
|
||||||
|
# Sample action sequences: (Batch, Num_Samples, Horizon, Dim)
|
||||||
|
candidates = torch.randn(
|
||||||
|
current_bs,
|
||||||
|
self.num_samples,
|
||||||
|
self.horizon,
|
||||||
|
self.action_dim,
|
||||||
|
generator=self.torch_gen,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Scale and shift: (Batch, N, H, D) * (Batch, 1, H, D) + (Batch, 1, H, D)
|
||||||
|
candidates = candidates * batch_var.unsqueeze(1) + batch_mean.unsqueeze(1)
|
||||||
|
|
||||||
|
# Force the first sample to be the current mean
|
||||||
|
candidates[:, 0] = batch_mean
|
||||||
|
|
||||||
|
current_info = expanded_infos.copy()
|
||||||
|
|
||||||
|
# Evaluate candidates
|
||||||
|
costs = self.model.get_cost(current_info, candidates)
|
||||||
|
|
||||||
|
assert isinstance(costs, torch.Tensor), f"Expected cost to be a torch.Tensor, got {type(costs)}"
|
||||||
|
assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, (
|
||||||
|
f"Expected cost to be of shape ({current_bs}, {self.num_samples}), got {costs.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Select Top-K
|
||||||
|
# topk_vals: (Batch, K), topk_inds: (Batch, K)
|
||||||
|
topk_vals, topk_inds = torch.topk(costs, k=self.topk, dim=1, largest=False)
|
||||||
|
|
||||||
|
# Gather Top-K Candidates
|
||||||
|
# We need to select the specific candidates corresponding to topk_inds
|
||||||
|
# Indexing: candidates[batch_idx, sample_idx]
|
||||||
|
# Result shape: (Batch, K, Horizon, Dim)
|
||||||
|
topk_candidates = candidates[batch_indices, topk_inds]
|
||||||
|
|
||||||
|
# Update Mean and Variance based on Top-K
|
||||||
|
batch_mean = topk_candidates.mean(dim=1)
|
||||||
|
batch_var = topk_candidates.std(dim=1)
|
||||||
|
|
||||||
|
# Update final cost for logging
|
||||||
|
# We average the cost of the top elites
|
||||||
|
final_batch_cost = topk_vals.mean(dim=1).detach()
|
||||||
|
|
||||||
|
# Write results back to global storage
|
||||||
|
if active_mask is not None:
|
||||||
|
global_indices = start_idx + active_local
|
||||||
|
mean[global_indices] = batch_mean
|
||||||
|
var[global_indices] = batch_var
|
||||||
|
batch_costs = torch.full(
|
||||||
|
(end_idx - start_idx,), float("nan"), device=self.device
|
||||||
|
)
|
||||||
|
batch_costs[active_local] = final_batch_cost
|
||||||
|
else:
|
||||||
|
mean[start_idx:end_idx] = batch_mean
|
||||||
|
var[start_idx:end_idx] = batch_var
|
||||||
|
batch_costs = final_batch_cost
|
||||||
|
|
||||||
|
# Store history/metadata
|
||||||
|
outputs["costs"].append(batch_costs)
|
||||||
|
|
||||||
|
if outputs["costs"]:
|
||||||
|
outputs["costs"] = torch.cat(outputs["costs"]).cpu().tolist()
|
||||||
|
else:
|
||||||
|
outputs["costs"] = []
|
||||||
|
outputs["actions"] = mean.detach()
|
||||||
|
outputs["mean"] = [mean.detach()]
|
||||||
|
outputs["var"] = [var.detach()]
|
||||||
|
|
||||||
|
print(f"CEM solve time: {time.time() - start_time:.4f} seconds")
|
||||||
|
return outputs
|
||||||
@@ -0,0 +1,256 @@
|
|||||||
|
"""Projected Gradient Descent solver for discrete action spaces."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from gymnasium.spaces import Discrete
|
||||||
|
|
||||||
|
from .solver import Costable
|
||||||
|
|
||||||
|
|
||||||
|
class PGDSolver(torch.nn.Module):
    """Projected Gradient Descent solver for discrete action optimization.

    Every discrete action is represented by a vector over its ``n`` choices.
    The solver takes SGD steps on these vectors using the gradient of the
    model cost, projects each of them back onto the probability simplex after
    every step, and finally reads the discrete action off with ``argmax``.

    Args:
        model: World model implementing the Costable protocol.
        n_steps: Number of gradient descent iterations.
        batch_size: Number of environments to process in parallel
            (``None`` processes all environments at once).
        var_scale: Scale of the Gaussian noise added to every sample except
            the first one when the action tensor is initialised.
        num_samples: Number of action samples to optimize in parallel.
        action_noise: Scale of the Gaussian noise added to the actions after
            each gradient step (``0`` disables it).
        device: Device for tensor computations.
        seed: Random seed for reproducibility.
    """

    def __init__(
        self,
        model: Costable,
        n_steps: int,
        batch_size: int | None = None,
        var_scale: float = 1,
        num_samples: int = 1,
        action_noise: float = 0.0,
        device: str | torch.device = "cpu",
        seed: int = 1234,
    ) -> None:
        super().__init__()
        self.model = model
        self.n_steps = n_steps
        self.batch_size = batch_size
        self.num_samples = num_samples
        self.var_scale = var_scale
        self.action_noise = action_noise
        self.device = device
        self.torch_gen = torch.Generator(device=device).manual_seed(seed)

        # Filled in by configure().
        self._configured = False
        self._n_envs = None
        self._action_dim = None
        self._action_simplex_dim = None
        self._config = None

    def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
        """Configure the solver with environment specifications.

        Args:
            action_space: Must be a ``Discrete`` space; its ``n`` gives the
                number of choices per action.
            n_envs: Number of parallel environments.
            config: Planning config; ``horizon`` and ``action_block`` are read
                from it by the properties below.
        """
        assert isinstance(action_space, Discrete), f"Action space must be discrete, got {type(action_space)}"

        self._action_space = action_space
        self._n_envs = n_envs
        self._config = config
        self._action_dim = int(np.prod(action_space.shape[1:]))
        self._action_simplex_dim = int(action_space.n)
        self._configured = True

    @property
    def n_envs(self) -> int:
        """Number of parallel environments."""
        return self._n_envs

    @property
    def action_dim(self) -> int:
        """Flattened action dimension including action_block grouping."""
        return self._action_dim * self._config.action_block

    @property
    def action_simplex_dim(self) -> int:
        """Simplex dimension for discrete action probabilities."""
        return self._action_simplex_dim * self._config.action_block

    @property
    def horizon(self) -> int:
        """Planning horizon in timesteps."""
        return self._config.horizon

    def __call__(self, *args: Any, **kwargs: Any) -> dict:
        """Make solver callable, forwarding to solve()."""
        return self.solve(*args, **kwargs)

    def init_action(
        self, actions: torch.Tensor | None = None, from_scalar: bool = False
    ) -> None:
        """Initialize the action tensor for optimization.

        The result is stored in the ``init`` parameter with shape
        ``(n_envs, num_samples, horizon, action_simplex_dim)``.

        Args:
            actions: Optional warm start covering at most ``horizon`` steps;
                missing steps are filled with zeros.
            from_scalar: When True, ``actions`` holds integer action indices
                that are one-hot encoded before use.
        """
        # Build everything on the solver device: the noise below is drawn
        # there, and mixing devices in cat / += raises.
        device = torch.device(self.device)

        if actions is None:
            actions = torch.zeros(
                (self._n_envs, 0, self.action_simplex_dim), device=device
            )
        else:
            if from_scalar:
                # convert scalar to one-hot
                actions = torch.nn.functional.one_hot(
                    actions, num_classes=self._action_simplex_dim
                ).to(torch.float32)
                # merge action_block dim
                actions = actions.reshape(*actions.shape[:-2], self.action_simplex_dim)
            actions = actions.to(device)

        assert (
            actions.shape[0] == self._n_envs
            and actions.shape[1] <= self.horizon
            and actions.shape[2] == self.action_simplex_dim
        )

        # fill remaining action
        remaining = self.horizon - actions.shape[1]
        if remaining > 0:
            new_actions = torch.zeros(
                self._n_envs, remaining, self.action_simplex_dim, device=device
            )
            actions = torch.cat([actions, new_actions], dim=1)

        actions = actions.unsqueeze(1).repeat_interleave(self.num_samples, dim=1)  # add sample dim
        # Perturb every sample except the first, which keeps the exact seed.
        actions[:, 1:] += (
            torch.randn(actions[:, 1:].shape, generator=self.torch_gen, device=self.device) * self.var_scale
        )

        # reset actions
        if hasattr(self, "init"):
            self.init.copy_(actions)
        else:
            self.register_parameter("init", torch.nn.Parameter(actions))

    def _expand_infos(self, info_dict: dict, start_idx: int, end_idx: int) -> dict:
        """Slice ``info_dict`` to rows ``[start_idx, end_idx)`` and add the sample axis.

        Tensors and NumPy arrays are assumed to have the environment axis
        first. Values of any other type are passed through unchanged.
        """
        current_bs = end_idx - start_idx
        expanded = {}
        for key, value in info_dict.items():
            if torch.is_tensor(value):
                chunk = value[start_idx:end_idx].unsqueeze(1)
                chunk = chunk.expand(current_bs, self.num_samples, *chunk.shape[2:])
            elif isinstance(value, np.ndarray):
                chunk = np.repeat(
                    value[start_idx:end_idx][:, None, ...], self.num_samples, axis=1
                )
            else:
                chunk = value
            expanded[key] = chunk
        return expanded

    def solve(
        self,
        info_dict: dict,
        init_action: torch.Tensor | None = None,
        from_scalar: bool = False,
    ) -> dict:
        """Solve the planning problem using projected gradient descent.

        Args:
            info_dict: Per-environment inputs forwarded to the cost model.
            init_action: Optional warm start, see ``init_action``.
            from_scalar: Whether ``init_action`` holds integer action indices.

        Returns:
            A dict with ``"cost"`` (one list per batch holding the summed
            batch cost at every step) and ``"actions"`` (CPU tensor of
            discrete actions for the lowest-cost sample of each environment).
        """
        start_time = time.time()
        outputs = {
            "cost": [],  # one cost history (list of floats) per batch
            "actions": None,
        }

        with torch.no_grad():
            self.init_action(init_action, from_scalar=from_scalar)

        # Determine batch size (default to all envs if not specified which can cause memory issues)
        batch_size = self.batch_size if self.batch_size is not None else self.n_envs
        total_envs = self.n_envs

        # Results of each batch, concatenated at the end.
        batch_top_actions_list = []

        # --- Outer Loop: Iterate over batches ---
        for start_idx in range(0, total_envs, batch_size):
            end_idx = min(start_idx + batch_size, total_envs)
            current_bs = end_idx - start_idx

            # Optimise a detached copy so each batch has its own leaf tensor.
            batch_init = self.init[start_idx:end_idx].clone().detach()
            batch_init.requires_grad = True

            optim = torch.optim.SGD([batch_init], lr=1.0)

            expanded_infos = self._expand_infos(info_dict, start_idx, end_idx)

            # Perform Gradient Descent for this batch
            batch_cost_history = []

            for _ in range(self.n_steps):
                current_info = expanded_infos.copy()

                # Calculate cost using the batch parameter
                costs = self.model.get_cost(current_info, batch_init)

                assert isinstance(costs, torch.Tensor), f"Got {type(costs)} cost, expect torch.Tensor"
                assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, (
                    f"Cost should be of shape ({current_bs}, {self.num_samples}), got {costs.shape}"
                )
                assert costs.requires_grad, "Cost must requires_grad for PGD solver."

                cost = costs.sum()  # Sum cost for this batch
                cost.backward()
                optim.step()
                optim.zero_grad(set_to_none=True)

                with torch.no_grad():
                    # Add noise (drawn on the generator's device).
                    if self.action_noise > 0:
                        batch_init.add_(
                            torch.randn(
                                batch_init.shape,
                                generator=self.torch_gen,
                                device=self.device,
                            )
                            * self.action_noise
                        )

                    # projection onto simplex
                    batch_init.copy_(self._project_action_simplex(batch_init))

                batch_cost_history.append(cost.item())

            # Store cost history for this batch
            outputs["cost"].append(batch_cost_history)

            # Update the global self.init with the optimized batch values
            with torch.no_grad():
                self.init[start_idx:end_idx] = batch_init

            # Pick, per environment, the sample with the lowest cost. Note
            # that ``costs`` was evaluated before the last update step.
            top_idx = torch.argsort(costs, dim=1)[:, 0]
            batch_indices = torch.arange(current_bs, device=top_idx.device)

            top_actions_batch = batch_init[batch_indices, top_idx]

            # convert one-hot back to discrete actions
            top_actions_batch = self._factor_action_block(top_actions_batch).argmax(dim=-1)
            batch_top_actions_list.append(top_actions_batch.detach().cpu())

        # Concatenate all batch results
        outputs["actions"] = torch.cat(batch_top_actions_list, dim=0)
        end_time = time.time()
        print(f"PGDSolver.solve completed in {end_time - start_time:.4f} seconds.")

        return outputs

    def _factor_action_block(self, actions: torch.Tensor) -> torch.Tensor:
        """Factor the action block dimension from action_simplex_dim.

        Splits the last axis into ``(action_block, simplex_dim)``.
        """
        original_shape = actions.shape
        action_block = self._config.action_block
        simplex_dim = self._action_simplex_dim
        return actions.reshape(*original_shape[:-1], action_block, simplex_dim)

    def _project_action_simplex(self, actions: torch.Tensor) -> torch.Tensor:
        """Project the action onto the probability simplex.

        Each group of ``_action_simplex_dim`` entries is projected
        independently with the sort-based method: find the shift ``psi`` such
        that ``clamp(s - psi, min=0)`` sums to one. The input shape is kept.
        """
        original_shape = actions.shape

        # One row per simplex to project.
        s = self._factor_action_block(actions).reshape(-1, self._action_simplex_dim)

        mu, _ = torch.sort(s, descending=True, dim=-1)
        cumulative = mu.cumsum(dim=-1)

        d = s.size(-1)
        indices = torch.arange(1, d + 1, device=s.device, dtype=s.dtype)

        # Candidate shift when only the j largest entries stay positive.
        threshold = (cumulative - 1) / indices

        # rho = largest j for which the j-th largest entry stays positive.
        cond = (mu > threshold).to(torch.int32)
        rho = cond.cumsum(dim=-1)
        valid_rho = rho * cond
        rho_max = valid_rho.max(dim=-1, keepdim=True)[0]

        # Guard the gather index below against rho == 0.
        rho_min = torch.clamp(rho_max, min=1)
        psi = (cumulative.gather(-1, rho_min - 1) - 1) / rho_min

        projected = torch.clamp(s - psi, min=0.0).reshape(original_shape)
        return projected
|
||||||
@@ -0,0 +1,252 @@
|
|||||||
|
"""Gradient-based solver for model-based planning."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from gymnasium.spaces import Box
|
||||||
|
from loguru import logger as logging
|
||||||
|
|
||||||
|
from .solver import Costable
|
||||||
|
|
||||||
|
|
||||||
|
class GradientSolver(torch.nn.Module):
    """Gradient-based solver using backpropagation through the world model.

    Action sequences are optimised directly with a PyTorch optimizer using the
    gradient of the model cost; the lowest-cost sample of each environment is
    returned.

    Args:
        model: World model implementing the Costable protocol.
        n_steps: Number of gradient descent iterations.
        batch_size: Number of environments to process in parallel
            (``None`` processes all environments at once).
        var_scale: Scale of the Gaussian noise added to every sample except
            the first one when the action tensor is initialised.
        num_samples: Number of action samples to optimize in parallel.
        action_noise: Scale of the Gaussian noise added to the actions after
            each gradient step (``0`` disables it).
        device: Device for tensor computations.
        seed: Random seed for reproducibility.
        optimizer_cls: PyTorch optimizer class to use.
        optimizer_kwargs: Keyword arguments for the optimizer
            (defaults to ``{'lr': 1.0}``).
    """

    def __init__(
        self,
        model: Costable,
        n_steps: int,
        batch_size: int | None = None,
        var_scale: float = 1,
        num_samples: int = 1,
        action_noise: float = 0.0,
        device: str | torch.device = 'cpu',
        seed: int = 1234,
        optimizer_cls: type[torch.optim.Optimizer] = torch.optim.SGD,
        optimizer_kwargs: dict | None = None,
    ) -> None:
        super().__init__()
        self.model = model
        self.n_steps = n_steps
        self.batch_size = batch_size
        self.num_samples = num_samples
        self.var_scale = var_scale
        self.action_noise = action_noise
        self.device = device
        self.torch_gen = torch.Generator(device=device).manual_seed(seed)

        self.optimizer_cls = optimizer_cls
        self.optimizer_kwargs = (
            optimizer_kwargs if optimizer_kwargs is not None else {'lr': 1.0}
        )

        # Filled in by configure().
        self._configured = False
        self._n_envs = None
        self._action_dim = None
        self._config = None

    def configure(
        self, *, action_space: gym.Space, n_envs: int, config: Any
    ) -> None:
        """Configure the solver with environment specifications.

        Args:
            action_space: Batched action space; a warning is logged when it
                is not a ``Box``.
            n_envs: Number of parallel environments.
            config: Planning config; ``horizon`` and ``action_block`` are read
                from it by the properties below.
        """
        self._action_space = action_space
        self._n_envs = n_envs
        self._config = config
        self._action_dim = int(np.prod(action_space.shape[1:]))
        self._configured = True

        if not isinstance(action_space, Box):
            logging.warning(
                f'Action space is discrete, got {type(action_space)}. GradientSolver may not work as expected.'
            )

    @property
    def n_envs(self) -> int:
        """Number of parallel environments."""
        return self._n_envs

    @property
    def action_dim(self) -> int:
        """Flattened action dimension including action_block grouping."""
        return self._action_dim * self._config.action_block

    @property
    def horizon(self) -> int:
        """Planning horizon in timesteps."""
        return self._config.horizon

    def __call__(self, *args: Any, **kwargs: Any) -> dict:
        """Make solver callable, forwarding to solve()."""
        return self.solve(*args, **kwargs)

    def init_action(self, actions: torch.Tensor | None = None) -> None:
        """Initialize the action tensor for optimization.

        The result is stored in the ``init`` parameter with shape
        ``(n_envs, num_samples, horizon, action_dim)``.

        Args:
            actions: Optional warm start covering at most ``horizon`` steps;
                missing steps are filled with zeros.
        """
        device = torch.device(self.device)
        if actions is None:
            actions = torch.zeros(
                (self._n_envs, 0, self.action_dim), device=device
            )
        elif actions.device != device:
            actions = actions.to(device, non_blocking=True)

        # fill remaining action
        remaining = self.horizon - actions.shape[1]
        if remaining > 0:
            new_actions = torch.zeros(
                self._n_envs,
                remaining,
                self.action_dim,
                device=actions.device,
                dtype=actions.dtype,
            )
            actions = torch.cat([actions, new_actions], dim=1)

        actions = actions.unsqueeze(1).repeat_interleave(
            self.num_samples, dim=1
        )  # add sample dim
        # Perturb every sample except the first, which keeps the exact seed.
        actions[:, 1:] += (
            torch.randn(
                actions[:, 1:].shape,
                generator=self.torch_gen,
                device=self.device,
            )
            * self.var_scale
        )

        # reset actions
        if hasattr(self, 'init'):
            self.init.copy_(actions)
        else:
            self.register_parameter('init', torch.nn.Parameter(actions))

    def _expand_infos(
        self, info_dict: dict, start_idx: int, end_idx: int
    ) -> dict:
        """Slice ``info_dict`` to rows ``[start_idx, end_idx)`` and add the sample axis.

        Tensors and NumPy arrays are assumed to have the environment axis
        first. Values of any other type are passed through unchanged.
        """
        current_bs = end_idx - start_idx
        expanded = {}
        for key, value in info_dict.items():
            if torch.is_tensor(value):
                chunk = value[start_idx:end_idx].unsqueeze(1)
                chunk = chunk.expand(
                    current_bs, self.num_samples, *chunk.shape[2:]
                )
            elif isinstance(value, np.ndarray):
                chunk = np.repeat(
                    value[start_idx:end_idx][:, None, ...],
                    self.num_samples,
                    axis=1,
                )
            else:
                chunk = value
            expanded[key] = chunk
        return expanded

    def solve(
        self, info_dict: dict, init_action: torch.Tensor | None = None
    ) -> dict:
        """Solve the planning problem using gradient descent.

        Args:
            info_dict: Per-environment inputs forwarded to the cost model.
            init_action: Optional warm start, see ``init_action``.

        Returns:
            A dict with ``'cost'`` (one float per environment: the lowest
            sample cost seen at the last step) and ``'actions'`` (the
            lowest-cost action sequence of each environment).
        """
        start_time = time.time()
        outputs = {
            'cost': [],  # final per-env cost of each batch, flattened below
            'actions': None,
        }

        with torch.no_grad():
            self.init_action(init_action)

        # Determine batch size (default to all envs if not specified which can cause memory issues)
        batch_size = (
            self.batch_size if self.batch_size is not None else self.n_envs
        )
        total_envs = self.n_envs

        # Results of each batch, concatenated at the end.
        batch_top_actions_list = []

        # --- Outer Loop: Iterate over batches ---
        for start_idx in range(0, total_envs, batch_size):
            end_idx = min(start_idx + batch_size, total_envs)
            current_bs = end_idx - start_idx

            # Optimise a detached copy so each batch has its own leaf tensor.
            batch_init = self.init[start_idx:end_idx].clone().detach()
            batch_init.requires_grad = True

            # Fresh optimizer per batch, built from the class/kwargs given in __init__.
            optim = self.optimizer_cls([batch_init], **self.optimizer_kwargs)

            expanded_infos = self._expand_infos(info_dict, start_idx, end_idx)

            final_batch_cost = None

            for _ in range(self.n_steps):
                current_info = expanded_infos.copy()

                # Calculate cost using the batch parameter
                costs = self.model.get_cost(current_info, batch_init)

                assert isinstance(costs, torch.Tensor), (
                    f'Got {type(costs)} cost, expect torch.Tensor'
                )
                assert (
                    costs.ndim == 2
                    and costs.shape[0] == current_bs
                    and costs.shape[1] == self.num_samples
                ), (
                    f'Cost should be of shape ({current_bs}, {self.num_samples}), got {costs.shape}'
                )
                assert costs.requires_grad, (
                    'Cost must requires_grad for GD solver.'
                )

                cost = costs.sum()  # Sum cost for this batch
                cost.backward()
                optim.step()
                optim.zero_grad(set_to_none=True)

                # Add noise
                if self.action_noise > 0:
                    with torch.no_grad():
                        batch_init.add_(
                            torch.randn(
                                batch_init.shape,
                                generator=self.torch_gen,
                                device=self.device,
                            )
                            * self.action_noise
                        )

                final_batch_cost = costs.detach().min(dim=1).values

            # Store the final cost of this batch
            outputs['cost'].append(final_batch_cost)

            # Update the global self.init with the optimized batch values
            with torch.no_grad():
                self.init[start_idx:end_idx] = batch_init

            # Pick, per environment, the sample with the lowest cost. Note
            # that ``costs`` was evaluated before the last update step.
            top_idx = costs.argmin(dim=1)
            batch_indices = torch.arange(current_bs, device=self.device)

            top_actions_batch = batch_init[batch_indices, top_idx]
            batch_top_actions_list.append(top_actions_batch.detach())

        # Concatenate all batch results
        outputs['actions'] = torch.cat(batch_top_actions_list, dim=0)
        outputs['cost'] = torch.cat(outputs['cost']).cpu().tolist()
        end_time = time.time()
        print(
            f'GradientSolver.solve completed in {end_time - start_time:.4f} seconds.'
        )

        return outputs
|
||||||
@@ -0,0 +1,219 @@
|
|||||||
|
"""Improved Cross Entropy Method (iCEM) solver for model-based planning."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from gymnasium.spaces import Box
|
||||||
|
from loguru import logger as logging
|
||||||
|
|
||||||
|
from .solver import Costable
|
||||||
|
|
||||||
|
|
||||||
|
class ICEMSolver:
|
||||||
|
"""Improved Cross Entropy Method (iCEM) solver with colored noise and elite retention.
|
||||||
|
iCEM improves the sample efficiency over standard CEM and was introduced by
|
||||||
|
[1] for real-time planning.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: World model implementing the Costable protocol.
|
||||||
|
batch_size: Number of environments to process in parallel.
|
||||||
|
num_samples: Number of action candidates to sample per iteration.
|
||||||
|
var_scale: Initial variance scale for the action distribution.
|
||||||
|
n_steps: Number of CEM iterations.
|
||||||
|
topk: Number of elite samples to keep for distribution update.
|
||||||
|
noise_beta: Colored noise exponent. 0 = white (standard CEM), >0 = more low-frequency noise.
|
||||||
|
alpha: Momentum for mean/std EMA update.
|
||||||
|
n_elite_keep: Number of elites carried from previous iteration.
|
||||||
|
return_mean: If False, return best single trajectory instead of mean.
|
||||||
|
device: Device for tensor computations.
|
||||||
|
seed: Random seed for reproducibility.
|
||||||
|
|
||||||
|
[1] C. Pinneri, S. Sawant, S. Blaes, J. Achterhold, J. Stueckler, M. Rolinek and
|
||||||
|
G. Martius. "Sample-efficient Cross-Entropy Method for Real-time Planning".
|
||||||
|
Conference on Robot Learning, 2020.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
    self,
    model: Costable,
    batch_size: int = 1,
    num_samples: int = 300,
    var_scale: float = 1,
    n_steps: int = 30,
    topk: int = 30,
    noise_beta: float = 2.0,
    alpha: float = 0.1,
    n_elite_keep: int = 5,
    return_mean: bool = True,
    device: str | torch.device = "cpu",
    seed: int = 1234,
) -> None:
    """Store the planner hyper-parameters and seed the sampling generator."""
    # Cost model and where to run it.
    self.model = model
    self.device = device
    # Dedicated generator on the solver device, seeded for reproducibility.
    self.torch_gen = torch.Generator(device=device).manual_seed(seed)

    # Sampling budget.
    self.batch_size = batch_size
    self.num_samples = num_samples
    self.n_steps = n_steps
    self.topk = topk

    # Distribution shape and update rule.
    self.var_scale = var_scale
    self.noise_beta = noise_beta
    self.alpha = alpha
    self.n_elite_keep = n_elite_keep

    # Output selection.
    self.return_mean = return_mean
|
||||||
|
|
||||||
|
def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
    """Bind the solver to an environment batch.

    Records the action space, the number of environments and the planning
    config, derives the per-step action size from the trailing dimensions of
    the action space, and caches the action bounds as float32 tensors on the
    solver device. For non-``Box`` spaces a warning is logged and both bounds
    are set to ``None``.
    """
    self._action_space = action_space
    self._n_envs = n_envs
    self._config = config
    # First axis of the space shape is skipped; the rest is flattened.
    self._action_dim = int(np.prod(action_space.shape[1:]))
    self._configured = True

    if not isinstance(action_space, Box):
        logging.warning(f"Action space is discrete, got {type(action_space)}. ICEMSolver may not work as expected.")
        self._action_low = None
        self._action_high = None
        return

    # Bounds are read from the first entry along the leading axis.
    low, high = action_space.low[0], action_space.high[0]
    self._action_low = torch.tensor(low, device=self.device, dtype=torch.float32)
    self._action_high = torch.tensor(high, device=self.device, dtype=torch.float32)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def n_envs(self) -> int:
|
||||||
|
"""Number of parallel environments."""
|
||||||
|
return self._n_envs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def action_dim(self) -> int:
|
||||||
|
"""Flattened action dimension including action_block grouping."""
|
||||||
|
return self._action_dim * self._config.action_block
|
||||||
|
|
||||||
|
@property
|
||||||
|
def horizon(self) -> int:
|
||||||
|
"""Planning horizon in timesteps."""
|
||||||
|
return self._config.horizon
|
||||||
|
|
||||||
|
def __call__(self, *args: Any, **kwargs: Any) -> dict:
|
||||||
|
"""Make solver callable, forwarding to solve()."""
|
||||||
|
return self.solve(*args, **kwargs)
|
||||||
|
|
||||||
|
def init_action_distrib(
|
||||||
|
self, actions: torch.Tensor | None = None
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Initialize the action distribution parameters (mean and variance)."""
|
||||||
|
var = self.var_scale * torch.ones([self.n_envs, self.horizon, self.action_dim])
|
||||||
|
mean = torch.zeros([self.n_envs, 0, self.action_dim]) if actions is None else actions
|
||||||
|
|
||||||
|
remaining = self.horizon - mean.shape[1]
|
||||||
|
if remaining > 0:
|
||||||
|
device = mean.device
|
||||||
|
new_mean = torch.zeros([self.n_envs, remaining, self.action_dim])
|
||||||
|
mean = torch.cat([mean, new_mean], dim=1).to(device)
|
||||||
|
|
||||||
|
return mean, var
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def solve(
|
||||||
|
self, info_dict: dict, init_action: torch.Tensor | None = None
|
||||||
|
) -> dict:
|
||||||
|
"""Solve the planning problem using improved Cross Entropy Method."""
|
||||||
|
start_time = time.time()
|
||||||
|
outputs = {
|
||||||
|
"costs": [],
|
||||||
|
"mean": [],
|
||||||
|
"var": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
mean, var = self.init_action_distrib(init_action)
|
||||||
|
mean = mean.to(self.device)
|
||||||
|
var = var.to(self.device)
|
||||||
|
|
||||||
|
for start_idx in range(0, self.n_envs, self.batch_size):
|
||||||
|
end_idx = min(start_idx + self.batch_size, self.n_envs)
|
||||||
|
current_bs = end_idx - start_idx
|
||||||
|
|
||||||
|
batch_mean = mean[start_idx:end_idx]
|
||||||
|
batch_var = var[start_idx:end_idx]
|
||||||
|
|
||||||
|
expanded_infos = {}
|
||||||
|
for k, v in info_dict.items():
|
||||||
|
v_batch = v[start_idx:end_idx]
|
||||||
|
if torch.is_tensor(v):
|
||||||
|
v_batch = v_batch.unsqueeze(1)
|
||||||
|
v_batch = v_batch.expand(current_bs, self.num_samples, *v_batch.shape[2:])
|
||||||
|
elif isinstance(v, np.ndarray):
|
||||||
|
v_batch = np.repeat(v_batch[:, None, ...], self.num_samples, axis=1)
|
||||||
|
expanded_infos[k] = v_batch
|
||||||
|
|
||||||
|
prev_topk_candidates = None
|
||||||
|
batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1).expand(-1, self.topk)
|
||||||
|
|
||||||
|
# Precompute FFT scale for colored noise
|
||||||
|
noise_shape = (current_bs, self.num_samples, self.action_dim, self.horizon)
|
||||||
|
freqs = torch.fft.rfftfreq(self.horizon, device=self.device)
|
||||||
|
freqs[0] = 1.0
|
||||||
|
noise_scale = freqs.pow(-self.noise_beta / 2)
|
||||||
|
noise_scale[0] = noise_scale[1]
|
||||||
|
|
||||||
|
for step in range(self.n_steps):
|
||||||
|
# Colored noise: generate with temporal axis last, then transpose
|
||||||
|
if self.horizon <= 1:
|
||||||
|
noise = torch.randn(noise_shape, generator=self.torch_gen, device=self.device)
|
||||||
|
else:
|
||||||
|
white = torch.randn(noise_shape, generator=self.torch_gen, device=self.device)
|
||||||
|
fft = torch.fft.rfft(white, dim=-1)
|
||||||
|
colored = torch.fft.irfft(fft * noise_scale, n=self.horizon, dim=-1)
|
||||||
|
std = colored.std(dim=-1, keepdim=True).clamp(min=1e-8)
|
||||||
|
noise = colored / std
|
||||||
|
noise = noise.transpose(-1, -2) # -> (bs, num_samples, horizon, action_dim)
|
||||||
|
|
||||||
|
candidates = noise * batch_var.unsqueeze(1) + batch_mean.unsqueeze(1)
|
||||||
|
candidates[:, 0] = batch_mean
|
||||||
|
|
||||||
|
# Inject previous elites
|
||||||
|
if prev_topk_candidates is not None:
|
||||||
|
n_inject = min(self.n_elite_keep, prev_topk_candidates.shape[1])
|
||||||
|
candidates[:, 1:1 + n_inject] = prev_topk_candidates[:, :n_inject]
|
||||||
|
|
||||||
|
# Clip to action bounds
|
||||||
|
if self._action_low is not None:
|
||||||
|
candidates = candidates.clamp(self._action_low, self._action_high)
|
||||||
|
|
||||||
|
current_info = expanded_infos.copy()
|
||||||
|
costs = self.model.get_cost(current_info, candidates)
|
||||||
|
|
||||||
|
assert isinstance(costs, torch.Tensor), f"Expected cost to be a torch.Tensor, got {type(costs)}"
|
||||||
|
assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, (
|
||||||
|
f"Expected cost to be of shape ({current_bs}, {self.num_samples}), got {costs.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
topk_vals, topk_inds = torch.topk(costs, k=self.topk, dim=1, largest=False)
|
||||||
|
topk_candidates = candidates[batch_indices, topk_inds]
|
||||||
|
|
||||||
|
prev_topk_candidates = topk_candidates
|
||||||
|
|
||||||
|
# Momentum update
|
||||||
|
elite_mean = topk_candidates.mean(dim=1)
|
||||||
|
elite_var = topk_candidates.std(dim=1)
|
||||||
|
batch_mean = self.alpha * batch_mean + (1 - self.alpha) * elite_mean
|
||||||
|
batch_var = self.alpha * batch_var + (1 - self.alpha) * elite_var
|
||||||
|
|
||||||
|
final_batch_cost = topk_vals.mean(dim=1).cpu().tolist()
|
||||||
|
|
||||||
|
if self.return_mean:
|
||||||
|
mean[start_idx:end_idx] = batch_mean
|
||||||
|
else:
|
||||||
|
mean[start_idx:end_idx] = topk_candidates[:, 0]
|
||||||
|
|
||||||
|
var[start_idx:end_idx] = batch_var
|
||||||
|
|
||||||
|
outputs["costs"].extend(final_batch_cost)
|
||||||
|
|
||||||
|
outputs["actions"] = mean.detach().cpu()
|
||||||
|
outputs["mean"] = [mean.detach().cpu()]
|
||||||
|
outputs["var"] = [var.detach().cpu()]
|
||||||
|
|
||||||
|
print(f"iCEM solve time: {time.time() - start_time:.4f} seconds")
|
||||||
|
return outputs
|
||||||
@@ -0,0 +1,360 @@
|
|||||||
|
"""Lagrangian solver for stable world model."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from gymnasium.spaces import Box
|
||||||
|
from loguru import logger as logging
|
||||||
|
|
||||||
|
from .solver import Costable
|
||||||
|
|
||||||
|
|
||||||
|
class LagrangianSolver(torch.nn.Module):
    """Augmented-Lagrangian gradient solver for constrained planning.

    ``model.get_cost`` returns the cost tensor (B, S). If the model also
    implements ``get_constraints``, that method returns the constraint values
    (B, S, C), where C is the number of constraints and constraint ``i`` is
    satisfied when ``g_i <= 0``. The solver minimizes

        L = cost + sum_i lambda_i * g_i + rho * sum_i max(0, g_i)^2

    over the action sequences, and after each inner optimization performs the
    dual-ascent step ``lambda_i <- max(0, lambda_i + rho * mean_samples(g_i))``
    followed by ``rho <- min(rho_max, rho * rho_scale)``. A single ``rho`` is
    shared by all constraints.

    An equality constraint ``g_i == 0`` can be expressed as the two
    inequalities ``g_i <= 0`` and ``-g_i <= 0``.

    Args:
        model: World model implementing the Costable protocol. Its get_cost()
            returns a plain cost tensor (B, S). If it also has
            get_constraints(), that method returns constraints of shape
            (B, S, C).
        n_steps: Number of gradient descent steps per outer iteration.
        n_outer_steps: Number of dual ascent (outer) iterations.
        batch_size: Number of environments to process in parallel
            (``None`` processes all of them at once).
        num_samples: Number of action samples to optimize in parallel.
        var_scale: Scale of the Gaussian perturbation applied to every sample
            except the first when the actions are initialized.
        action_noise: Scale of the Gaussian noise added to the actions after
            every gradient step (0 disables it).
        rho_init: Initial penalty coefficient for the quadratic constraint term.
        rho_max: Maximum value of the penalty coefficient.
        rho_scale: Multiplicative growth factor for rho after each outer step.
        persist_multipliers: Whether to warm-start Lagrange multipliers across
            solve() calls.
        device: Device for tensor computations.
        seed: Random seed for reproducibility.
        optimizer_cls: PyTorch optimizer class to use.
        optimizer_kwargs: Keyword arguments for the optimizer
            (defaults to ``{'lr': 1.0}``).
    """

    def __init__(
        self,
        model: Costable,
        n_steps: int,
        n_outer_steps: int = 5,
        batch_size: int | None = None,
        num_samples: int = 1,
        var_scale: float = 1.0,
        action_noise: float = 0.0,
        rho_init: float = 1.0,
        rho_max: float = 1e4,
        rho_scale: float = 2.0,
        persist_multipliers: bool = True,
        device: str | torch.device = 'cpu',
        seed: int = 1234,
        optimizer_cls: type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_kwargs: dict | None = None,
    ) -> None:
        super().__init__()
        self.model = model
        self.n_steps = n_steps
        self.n_outer_steps = n_outer_steps
        self.batch_size = batch_size
        self.num_samples = num_samples
        self.var_scale = var_scale
        self.action_noise = action_noise
        self.rho_init = rho_init
        self.rho_max = rho_max
        self.rho_scale = rho_scale
        self.persist_multipliers = persist_multipliers
        self.device = device
        self.torch_gen = torch.Generator(device=device).manual_seed(seed)
        self.optimizer_cls = optimizer_cls
        self.optimizer_kwargs = (
            optimizer_kwargs if optimizer_kwargs is not None else {'lr': 1.0}
        )

        self._configured = False
        self._n_envs = None
        self._action_dim = None
        self._config = None
        self._lambdas: torch.Tensor | None = None  # (n_envs, C)

    def configure(
        self, *, action_space: gym.Space, n_envs: int, config: Any
    ) -> None:
        """Configure the solver with environment specifications.

        The base action size is the product of ``action_space.shape[1:]``,
        i.e. the leading dimension of the space is not counted.
        """
        self._action_space = action_space
        self._n_envs = n_envs
        self._config = config
        self._action_dim = int(np.prod(action_space.shape[1:]))
        self._configured = True

        if not isinstance(action_space, Box):
            logging.warning(
                f'Action space is discrete, got {type(action_space)}. LagrangianSolver may not work as expected.'
            )

    @property
    def n_envs(self) -> int:
        """Number of parallel environments."""
        return self._n_envs

    @property
    def action_dim(self) -> int:
        """Flattened action dimension including action_block grouping."""
        return self._action_dim * self._config.action_block

    @property
    def horizon(self) -> int:
        """Planning horizon in timesteps."""
        return self._config.horizon

    def __call__(self, *args: Any, **kwargs: Any) -> dict:
        """Make solver callable, forwarding to solve()."""
        return self.solve(*args, **kwargs)

    def init_action(self, actions: torch.Tensor | None = None) -> None:
        """Initialize the ``init`` parameter holding the actions to optimize.

        The warm start (if any) is moved to the solver device, right-padded
        with zeros up to the horizon, replicated ``num_samples`` times along a
        new sample dimension, and every sample but the first is perturbed with
        Gaussian noise scaled by ``var_scale``.

        Args:
            actions: Optional warm start of shape
                ``(n_envs, <=horizon, action_dim)``.
        """
        # Everything is created on / moved to self.device up front: padding on
        # the default device or skipping the move for full-length warm starts
        # led to device mismatches in torch.cat and in the noise addition.
        if actions is None:
            actions = torch.zeros(
                (self._n_envs, 0, self.action_dim), device=self.device
            )
        else:
            actions = actions.to(self.device)

        remaining = self.horizon - actions.shape[1]
        if remaining > 0:
            padding = torch.zeros(
                self._n_envs, remaining, self.action_dim, device=self.device
            )
            actions = torch.cat([actions, padding], dim=1)

        # repeat_interleave copies, so the caller's tensor is never modified.
        actions = actions.unsqueeze(1).repeat_interleave(
            self.num_samples, dim=1
        )
        actions[:, 1:] += (
            torch.randn(
                actions[:, 1:].shape,
                generator=self.torch_gen,
                device=self.device,
            )
            * self.var_scale
        )

        existing = getattr(self, 'init', None)
        with torch.no_grad():
            if existing is not None and existing.shape == actions.shape:
                existing.copy_(actions)
            else:
                # First call, or the problem size changed since the last call.
                self.init = torch.nn.Parameter(actions)

    def _init_multipliers(self, num_constraints: int) -> None:
        """Lazily initialize Lagrange multipliers to zeros, shape (n_envs, C)."""
        self._lambdas = torch.zeros(
            self._n_envs, num_constraints, device=self.device
        )

    def _augmented_lagrangian_loss(
        self,
        costs: torch.Tensor,  # (B, S)
        constraints: torch.Tensor,  # (B, S, C)
        lambdas_batch: torch.Tensor,  # (B, C)
        rho: float,
    ) -> torch.Tensor:
        """Compute the augmented Lagrangian loss summed over envs and samples.

        L = cost + sum_i lambda_i * g_i + rho * sum_i max(0, g_i)^2
        """
        # lambdas_batch: (B, C) -> (B, 1, C) for broadcasting with constraints (B, S, C)
        linear_penalty = (lambdas_batch.unsqueeze(1) * constraints).sum(
            dim=-1
        )  # (B, S)
        quadratic_penalty = rho * F.relu(constraints).pow(2).sum(
            dim=-1
        )  # (B, S)
        return (costs + linear_penalty + quadratic_penalty).sum()

    def _update_multipliers(
        self,
        constraints: torch.Tensor,  # (B, S, C), detached, no grad
        lambdas_batch: torch.Tensor,  # (B, C)
        rho: float,
    ) -> torch.Tensor:
        """Dual ascent: lambda_i <- max(0, lambda_i + rho * mean_samples(g_i))."""
        mean_g = constraints.mean(dim=1)  # (B, C)
        return torch.clamp(lambdas_batch + rho * mean_g, min=0.0)

    def _expand_infos(
        self, info_dict: dict, start_idx: int, end_idx: int
    ) -> dict:
        """Slice tensor/array infos to one batch and add the sample dimension.

        Values that are neither tensors nor numpy arrays are passed through.
        """
        current_bs = end_idx - start_idx
        expanded = {}
        for key, value in info_dict.items():
            if torch.is_tensor(value):
                batch_v = value[start_idx:end_idx].unsqueeze(1)
                batch_v = batch_v.expand(
                    current_bs, self.num_samples, *batch_v.shape[2:]
                )
            elif isinstance(value, np.ndarray):
                batch_v = np.repeat(
                    value[start_idx:end_idx][:, None, ...],
                    self.num_samples,
                    axis=1,
                )
            else:
                batch_v = value
            expanded[key] = batch_v
        return expanded

    def solve(
        self, info_dict: dict, init_action: torch.Tensor | None = None
    ) -> dict:
        """Solve the planning problem using augmented Lagrangian gradient descent.

        Args:
            info_dict: Per-environment information forwarded to the model.
            init_action: Optional warm-start action sequence.

        Returns:
            Dict with:
              * ``'cost'``: one list per batch with the loss after every
                gradient step,
              * ``'constraint_violation'``: one mean positive violation per
                batch (empty without constraints),
              * ``'actions'``: for every environment, the sample with the
                lowest cost as measured at the last gradient step (that cost
                is evaluated before the final update is applied and ignores
                the constraints),
              * ``'lambdas'``: a copy of the multipliers, or ``None``.
        """
        start_time = time.time()
        outputs: dict = {
            'cost': [],
            'constraint_violation': [],
            'actions': None,
            'lambdas': None,
        }

        with torch.no_grad():
            self.init_action(init_action)

        if not self.persist_multipliers:
            self._lambdas = None

        batch_size = (
            self.batch_size if self.batch_size is not None else self.n_envs
        )
        total_envs = self.n_envs
        has_constraints = hasattr(self.model, 'get_constraints')
        batch_top_actions_list = []

        for start_idx in range(0, total_envs, batch_size):
            end_idx = min(start_idx + batch_size, total_envs)
            current_bs = end_idx - start_idx

            batch_init = self.init[start_idx:end_idx].clone().detach()
            batch_init.requires_grad = True

            expanded_infos = self._expand_infos(info_dict, start_idx, end_idx)

            rho = self.rho_init
            batch_cost_history = []
            # Pre-set so that n_steps == 0 / n_outer_steps == 0 degrade
            # gracefully instead of raising NameError / AttributeError.
            costs = None
            constraints = None
            lambdas_batch = None
            final_constraints = None

            for _outer in range(self.n_outer_steps):
                # Fresh optimizer each outer step avoids stale momentum after dual ascent
                optim = self.optimizer_cls(
                    [batch_init], **self.optimizer_kwargs
                )

                for _step in range(self.n_steps):
                    # Shallow copies: the model may add keys to the dict.
                    costs = self.model.get_cost(
                        expanded_infos.copy(), batch_init
                    )
                    constraints = (
                        self.model.get_constraints(
                            expanded_infos.copy(), batch_init
                        )
                        if has_constraints
                        else None
                    )

                    assert isinstance(costs, torch.Tensor), (
                        f'Got {type(costs)} cost, expect torch.Tensor'
                    )
                    assert costs.ndim == 2 and costs.shape == (
                        current_bs,
                        self.num_samples,
                    ), (
                        f'Cost should be of shape ({current_bs}, {self.num_samples}), got {costs.shape}'
                    )
                    assert costs.requires_grad, (
                        'Cost must requires_grad for LagrangianSolver.'
                    )

                    if constraints is not None:
                        assert constraints.ndim == 3 and constraints.shape[
                            :2
                        ] == (current_bs, self.num_samples), (
                            f'Constraints should be of shape ({current_bs}, {self.num_samples}), C), got {constraints.shape}'
                        ) if False else (
                            f'Constraints should be of shape ({current_bs}, {self.num_samples}, C), got {constraints.shape}'
                        )
                        num_constraints = constraints.shape[-1]
                        # Re-create persisted multipliers whose shape no
                        # longer matches the current problem.
                        if self._lambdas is None or self._lambdas.shape != (
                            self._n_envs,
                            num_constraints,
                        ):
                            self._init_multipliers(num_constraints)
                        lambdas_batch = self._lambdas[start_idx:end_idx]
                        loss = self._augmented_lagrangian_loss(
                            costs, constraints, lambdas_batch, rho
                        )
                    else:
                        loss = costs.sum()

                    loss.backward()
                    optim.step()
                    optim.zero_grad(set_to_none=True)

                    if self.action_noise > 0:
                        # The noise must be drawn on the generator's device.
                        batch_init.data += (
                            torch.randn(
                                batch_init.shape,
                                generator=self.torch_gen,
                                device=batch_init.device,
                            )
                            * self.action_noise
                        )

                    batch_cost_history.append(loss.item())

                # Dual ascent after the inner loop
                if constraints is not None:
                    with torch.no_grad():
                        final_constraints = self.model.get_constraints(
                            expanded_infos.copy(), batch_init
                        )
                        lambdas_batch = self._update_multipliers(
                            final_constraints, lambdas_batch, rho
                        )
                        self._lambdas[start_idx:end_idx] = lambdas_batch
                    rho = min(self.rho_max, rho * self.rho_scale)

                if costs is None:
                    continue  # no gradient step ran, nothing to report

                with torch.no_grad():
                    mean_cost = costs.mean().item()
                    if constraints is not None:
                        viol = F.relu(final_constraints).mean(dim=(0, 1))  # (C,)
                        lam = lambdas_batch.mean(dim=0)  # (C,)
                        viol_str = ', '.join(f'{v:.4f}' for v in viol.tolist())
                        lam_str = ', '.join(
                            f'{lam_i:.4f}' for lam_i in lam.tolist()
                        )
                        print(
                            f' [outer {_outer+1}/{self.n_outer_steps}] '
                            f'cost={mean_cost:.4f} | '
                            f'constraint_viol=[{viol_str}] | '
                            f'lambdas=[{lam_str}] | '
                            f'rho={rho:.4f}'
                        )
                    else:
                        print(
                            f' [outer {_outer+1}/{self.n_outer_steps}] '
                            f'cost={mean_cost:.4f}'
                        )

            outputs['cost'].append(batch_cost_history)

            if final_constraints is not None:
                outputs['constraint_violation'].append(
                    F.relu(final_constraints).mean().item()
                )

            with torch.no_grad():
                self.init[start_idx:end_idx] = batch_init

            if costs is None:
                # Nothing was optimized: fall back to the unperturbed sample 0.
                top_idx = torch.zeros(
                    current_bs, dtype=torch.long, device=batch_init.device
                )
            else:
                top_idx = torch.argsort(costs, dim=1)[:, 0]
            batch_indices = torch.arange(current_bs, device=top_idx.device)
            top_actions_batch = batch_init[batch_indices, top_idx]
            batch_top_actions_list.append(top_actions_batch.detach().cpu())

        outputs['actions'] = torch.cat(batch_top_actions_list, dim=0)
        # Clone: Tensor.cpu() returns the same tensor on CPU, and the internal
        # multipliers are updated in place by later solve() calls.
        outputs['lambdas'] = (
            self._lambdas.detach().cpu().clone()
            if self._lambdas is not None
            else None
        )

        constraint_info = ''
        if outputs['constraint_violation']:
            mean_viol = np.mean(outputs['constraint_violation'])
            constraint_info = f' | constraint_violation={mean_viol:.4f}'
        print(
            f'LagrangianSolver.solve completed in {time.time() - start_time:.4f} seconds{constraint_info}.'
        )
        return outputs
|
||||||
@@ -0,0 +1,208 @@
|
|||||||
|
"""Model Predictive Path Integral solver for model-based planning."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from gymnasium.spaces import Box
|
||||||
|
from loguru import logger as logging
|
||||||
|
|
||||||
|
from .solver import Costable
|
||||||
|
|
||||||
|
|
||||||
|
class MPPISolver:
    """Model Predictive Path Integral solver for action optimization.

    Each iteration samples Gaussian perturbations around the current mean
    action sequence, scores them with ``model.get_cost`` and replaces the mean
    by the softmax-weighted average of the (optionally top-k) candidates.

    Args:
        model: World model implementing the Costable protocol.
        batch_size: Number of environments to process in parallel.
        num_samples: Number of action candidates to sample per iteration.
        var_scale: Scale that multiplies the unit Gaussian noise. It is kept
            fixed during optimization.
        n_steps: Number of MPPI iterations.
        topk: Number of elite samples for weighted averaging. ``None`` or a
            value not smaller than ``num_samples`` uses every sample.
        temperature: Temperature parameter for softmax weighting.
        device: Device for tensor computations.
        seed: Random seed for reproducibility.
    """

    def __init__(
        self,
        model: Costable,
        batch_size: int = 1,
        num_samples: int = 300,
        var_scale: float = 1.0,
        n_steps: int = 30,
        topk: int = 30,
        temperature: float = 0.5,
        device: str | torch.device = "cpu",
        seed: int = 1234,
    ) -> None:
        self.model = model
        self.batch_size = batch_size
        self.num_samples = num_samples
        self.topk = topk
        self.var_scale = var_scale
        self.n_steps = n_steps
        self.temperature = temperature
        self.device = device
        self.torch_gen = torch.Generator(device=device).manual_seed(seed)

        # Filled in by configure(); defined here so the attributes always exist.
        self._configured = False
        self._action_space = None
        self._n_envs = None
        self._action_dim = None
        self._config = None

    def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
        """Configure the solver with environment specifications.

        The base action size is the product of ``action_space.shape[1:]``,
        i.e. the leading dimension of the space is not counted.
        """
        self._action_space = action_space
        self._n_envs = n_envs
        self._config = config
        self._action_dim = int(np.prod(action_space.shape[1:]))
        self._configured = True

        if not isinstance(action_space, Box):
            logging.warning(
                f"Action space is discrete, got {type(action_space)}. MPPISolver may not work as expected."
            )

    @property
    def n_envs(self) -> int:
        """Number of parallel environments."""
        return self._n_envs

    @property
    def action_dim(self) -> int:
        """Flattened action dimension including action_block grouping."""
        return self._action_dim * self._config.action_block

    @property
    def horizon(self) -> int:
        """Planning horizon in timesteps."""
        return self._config.horizon

    def __call__(self, *args: Any, **kwargs: Any) -> dict:
        """Make solver callable, forwarding to solve()."""
        return self.solve(*args, **kwargs)

    def init_action_distrib(
        self, actions: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Initialize the action distribution parameters (mean and variance).

        Args:
            actions: Optional warm-start sequence, possibly shorter than the
                horizon. ``None`` starts from an all-zero mean.

        Returns:
            ``(mean, var)``: the warm start right-padded with zeros up to the
            horizon, and a tensor filled with ``var_scale``.
        """
        var = self.var_scale * torch.ones([self.n_envs, self.horizon, self.action_dim])
        if actions is None:
            mean = torch.zeros([self.n_envs, 0, self.action_dim])
        else:
            mean = actions

        remaining = self.horizon - mean.shape[1]
        if remaining > 0:
            # Pad on the warm start's device; torch.cat rejects mixed devices.
            padding = torch.zeros(
                [self.n_envs, remaining, self.action_dim], device=mean.device
            )
            mean = torch.cat([mean, padding], dim=1)

        return mean, var

    @torch.inference_mode()
    def solve(
        self, info_dict: dict, init_action: torch.Tensor | None = None
    ) -> dict:
        """Solve the planning problem using MPPI.

        Args:
            info_dict: Per-environment information. Tensors and numpy arrays
                are sliced per batch and broadcast over the sample dimension;
                other values are sliced when they support it and passed
                through unchanged otherwise.
            init_action: Optional warm-start action sequence. It is never
                modified in place.

        Returns:
            Dict with ``"actions"`` (final mean per environment), ``"costs"``
            (mean cost of the samples used in the last iteration, one entry
            per environment) and single-element lists ``"mean"`` / ``"var"``.
        """
        start_time = time.time()
        outputs = {
            "costs": [],
            "mean": [],
            "var": [],
        }

        # -- initialize the action distribution globally
        mean, var = self.init_action_distrib(init_action)
        # Tensor.to() may return the caller's own tensor; clone before the
        # in-place write-backs below.
        mean = mean.to(self.device).clone()
        var = var.to(self.device)

        total_envs = self.n_envs
        use_topk = self.topk is not None and self.topk < self.num_samples

        # --- Iterate over batches ---
        for start_idx in range(0, total_envs, self.batch_size):
            end_idx = min(start_idx + self.batch_size, total_envs)
            current_bs = end_idx - start_idx

            # Slice distribution parameters for the current batch
            batch_mean = mean[start_idx:end_idx]
            batch_var = var[start_idx:end_idx]

            # Expand info dict for the current batch
            expanded_infos = {}
            for k, v in info_dict.items():
                if torch.is_tensor(v):
                    # (batch, 1, ...) -> (batch, num_samples, ...)
                    v_batch = v[start_idx:end_idx].unsqueeze(1)
                    v_batch = v_batch.expand(current_bs, self.num_samples, *v_batch.shape[2:])
                elif isinstance(v, np.ndarray):
                    v_batch = np.repeat(v[start_idx:end_idx][:, None, ...], self.num_samples, axis=1)
                else:
                    try:
                        v_batch = v[start_idx:end_idx]
                    except TypeError:
                        # Scalars and other non-sequence metadata are shared as-is.
                        v_batch = v
                expanded_infos[k] = v_batch

            # Row indices for gathering the elites, identical for every step.
            batch_indices = None
            if use_topk:
                batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1).expand(-1, self.topk)

            # Stays empty when n_steps == 0 (extending with None raised TypeError).
            final_batch_cost = []

            for _step in range(self.n_steps):
                # Sample noise: (batch, num_samples, horizon, dim)
                noise = torch.randn(
                    current_bs,
                    self.num_samples,
                    self.horizon,
                    self.action_dim,
                    generator=self.torch_gen,
                    device=self.device,
                )

                # candidates = mean + noise * sigma
                candidates = batch_mean.unsqueeze(1) + noise * batch_var.unsqueeze(1)

                # Force the first sample to be the current mean (zero noise)
                candidates[:, 0] = batch_mean

                # Shallow copy so keys the model may add do not leak into the
                # next iteration (same convention as the other solvers).
                costs = self.model.get_cost(expanded_infos.copy(), candidates)

                assert isinstance(costs, torch.Tensor), f"Expected cost to be a torch.Tensor, got {type(costs)}"
                assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, (
                    f"Expected cost to be of shape ({current_bs}, {self.num_samples}), got {costs.shape}"
                )

                if use_topk:
                    # topk_vals / topk_inds: (batch, K)
                    topk_vals, topk_inds = torch.topk(costs, k=self.topk, dim=1, largest=False)
                    # (batch, K, horizon, dim)
                    relevant_candidates = candidates[batch_indices, topk_inds]
                    relevant_costs = topk_vals
                else:
                    relevant_candidates = candidates
                    relevant_costs = costs

                # Softmax(-cost / temperature), stabilized by subtracting the min cost
                min_cost = relevant_costs.min(dim=1, keepdim=True)[0]
                scaled_costs = relevant_costs - min_cost
                weights = torch.softmax(-scaled_costs / self.temperature, dim=1)  # (batch, K)

                # Weighted sum of candidates; weights -> (batch, K, 1, 1)
                weights_expanded = weights.unsqueeze(-1).unsqueeze(-1)
                batch_mean = (weights_expanded * relevant_candidates).sum(dim=1)

                # Average cost of the utilized samples, for logging
                final_batch_cost = relevant_costs.mean(dim=1).cpu().tolist()

            # Write results back to global storage (var is not updated in standard MPPI)
            mean[start_idx:end_idx] = batch_mean

            outputs["costs"].extend(final_batch_cost)

        outputs["actions"] = mean.detach().cpu()
        outputs["mean"] = [mean.detach().cpu()]
        outputs["var"] = [var.detach().cpu()]

        print(f"MPPI solve time: {time.time() - start_time:.4f} seconds")
        return outputs
|
||||||
@@ -0,0 +1,80 @@
|
|||||||
|
from typing import Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class Costable(Protocol):
    """Structural interface a world model must offer to be used by a solver."""

    def criterion(self, info_dict: dict, action_candidates: torch.Tensor) -> torch.Tensor:
        """Evaluate the cost criterion of the proposed actions.

        Args:
            info_dict: Environment state information keyed by name.
            action_candidates: Proposed actions to evaluate.

        Returns:
            Tensor holding one cost value per action candidate.
        """
        ...

    def get_cost(self, info_dict: dict, action_candidates: torch.Tensor) -> torch.Tensor:  # pragma: no cover
        """Return the cost of every action candidate given the info dictionary.

        Args:
            info_dict: Environment state information keyed by name.
            action_candidates: Proposed actions to evaluate.

        Returns:
            Tensor holding one cost value per action candidate.
        """
        ...
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
class Solver(Protocol):
    """Structural interface shared by all model-based planning solvers."""

    def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
        """Bind the solver to an environment and a planning configuration.

        Args:
            action_space: Action space of the environment.
            n_envs: How many environments are planned for in parallel.
            config: Planning configuration object.
        """
        ...

    @property
    def action_dim(self) -> int:
        """Flattened action dimension, action_block grouping included."""
        ...

    @property
    def n_envs(self) -> int:
        """How many parallel environments are being planned for."""
        ...

    @property
    def horizon(self) -> int:
        """Length of the planning horizon, in timesteps."""
        ...

    def solve(
        self,
        info_dict: dict,
        init_action: torch.Tensor | None = None,
        active_mask: torch.Tensor | np.ndarray | None = None,
    ) -> dict:
        """Run the planning optimization and return the resulting actions.

        Args:
            info_dict: Environment state information keyed by name.
            init_action: Optional action sequence used to warm-start the solver.
            active_mask: Optional mask; presumably it selects which
                environments should be planned for (TODO confirm against the
                solver implementations that accept it).

        Returns:
            Dictionary with the optimized actions plus solver-specific extras.
        """
        ...
|
||||||
1049
.venv/lib/python3.10/site-packages/stable_worldmodel/world.py
Normal file
1049
.venv/lib/python3.10/site-packages/stable_worldmodel/world.py
Normal file
File diff suppressed because it is too large
Load Diff
241
AMD_SETUP.md
Normal file
241
AMD_SETUP.md
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
# AMD ROCm 环境配置说明
|
||||||
|
|
||||||
|
这份文档记录了在 AMD ROCm 环境下运行 LeWM 的可复现配置,重点是保留
|
||||||
|
`torch.compile` 时的 PyTorch 版本选择。
|
||||||
|
|
||||||
|
目标运行命令:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python eval.py --config-name=pusht.yaml policy=pusht/lewm
|
||||||
|
```
|
||||||
|
|
||||||
|
## 已验证环境
|
||||||
|
|
||||||
|
本次验证通过的环境:
|
||||||
|
|
||||||
|
- Ubuntu 24.04
|
||||||
|
- AMD Radeon PRO W7900D (`gfx1100`)
|
||||||
|
- 系统 ROCm 7.1.1
|
||||||
|
- Python 3.10
|
||||||
|
- `torch==2.10.0+rocm7.1`
|
||||||
|
- `torchvision==0.25.0+rocm7.1`
|
||||||
|
- `triton-rocm==3.6.0`
|
||||||
|
|
||||||
|
注意:`torch==2.12.0+rocm7.1` 可以正常导入,也能识别 GPU,但在本项目里开启
|
||||||
|
`torch.compile` 后会崩溃,错误类似:
|
||||||
|
|
||||||
|
```text
|
||||||
|
HSA_STATUS_ERROR_INVALID_PACKET_FORMAT
|
||||||
|
CUDA error: unspecified launch failure
|
||||||
|
```
|
||||||
|
|
||||||
|
降级到 `torch==2.10.0+rocm7.1` 后,`torch.compile` 路径可以正常跑通。
|
||||||
|
|
||||||
|
## 检查系统 ROCm
|
||||||
|
|
||||||
|
在新 AMD 机器上,先确认系统能识别 GPU:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
rocminfo
|
||||||
|
amd-smi version
|
||||||
|
hipcc --version
|
||||||
|
```
|
||||||
|
|
||||||
|
`rocminfo` 里应该能看到 AMD GPU agent,例如 `gfx1100`。
|
||||||
|
|
||||||
|
## 创建 Python 环境
|
||||||
|
|
||||||
|
使用 `uv` 创建 Python 3.10 虚拟环境:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd /path/to/lewm
|
||||||
|
uv venv --python 3.10 --allow-existing .venv
|
||||||
|
source .venv/bin/activate
|
||||||
|
```
|
||||||
|
|
||||||
|
给 uv 创建的 venv 补上 pip。ROCm 版 PyTorch wheel 很大,如果 uv 解析或下载卡住,
|
||||||
|
用 pip 安装大 wheel 更容易观察进度。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv pip install pip
|
||||||
|
```
|
||||||
|
|
||||||
|
## 安装 ROCm 版 PyTorch
|
||||||
|
|
||||||
|
安装本项目已验证可用的 ROCm PyTorch 组合:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m pip install --force-reinstall \
|
||||||
|
--index-url https://download.pytorch.org/whl/rocm7.1 \
|
||||||
|
--extra-index-url https://pypi.tuna.tsinghua.edu.cn/simple \
|
||||||
|
"torch==2.10.0" \
|
||||||
|
"torchvision==0.25.0"
|
||||||
|
```
|
||||||
|
|
||||||
|
PyTorch wheel 有数 GB。如果网络慢,不要频繁中断重试,尽量等它下载完成。
|
||||||
|
|
||||||
|
## 安装项目依赖
|
||||||
|
|
||||||
|
普通 Python 包建议走国内 PyPI 镜像:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m pip install \
|
||||||
|
--index-url https://pypi.tuna.tsinghua.edu.cn/simple \
|
||||||
|
"gymnasium[all]==1.2.2" \
|
||||||
|
"stable-baselines3==2.8.0" \
|
||||||
|
"stable-worldmodel[train,env]"
|
||||||
|
```
|
||||||
|
|
||||||
|
然后修正两个容易被 pip 带偏的依赖版本:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m pip install \
|
||||||
|
--index-url https://pypi.tuna.tsinghua.edu.cn/simple \
|
||||||
|
"fsspec==2025.3.0" \
|
||||||
|
"pillow==11.3.0"
|
||||||
|
```
|
||||||
|
|
||||||
|
检查环境:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m pip check
|
||||||
|
python - <<'PY'
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
|
||||||
|
print("torch:", torch.__version__)
|
||||||
|
print("hip:", torch.version.hip)
|
||||||
|
print("cuda available:", torch.cuda.is_available())
|
||||||
|
print("device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else None)
|
||||||
|
print("torchvision:", torchvision.__version__)
|
||||||
|
PY
|
||||||
|
```
|
||||||
|
|
||||||
|
期望看到类似输出:
|
||||||
|
|
||||||
|
```text
|
||||||
|
torch: 2.10.0+rocm7.1
|
||||||
|
cuda available: True
|
||||||
|
torchvision: 0.25.0+rocm7.1
|
||||||
|
```
|
||||||
|
|
||||||
|
## 恢复本仓库里的 stable-worldmodel 修改
|
||||||
|
|
||||||
|
这个仓库把一些本地修改后的 `stable_worldmodel` 文件纳入了 git 管控,路径在:
|
||||||
|
|
||||||
|
```text
|
||||||
|
.venv/lib/python3.10/site-packages/stable_worldmodel/
|
||||||
|
```
|
||||||
|
|
||||||
|
从 PyPI 安装 `stable-worldmodel` 时可能会覆盖这些文件。安装依赖后执行:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git restore -- \
|
||||||
|
.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py \
|
||||||
|
.venv/lib/python3.10/site-packages/stable_worldmodel/world.py \
|
||||||
|
.venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py \
|
||||||
|
.venv/lib/python3.10/site-packages/stable_worldmodel/solver/gd.py
|
||||||
|
```
|
||||||
|
|
||||||
|
然后确认没有意外修改:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git status --short
|
||||||
|
```
|
||||||
|
|
||||||
|
## 数据和 checkpoint 路径
|
||||||
|
|
||||||
|
`eval.py` 会从 `$STABLEWM_HOME` 里找数据和 checkpoint。
|
||||||
|
|
||||||
|
PushT 评估至少需要:
|
||||||
|
|
||||||
|
```text
|
||||||
|
$STABLEWM_HOME/pusht_expert_train.h5
|
||||||
|
$STABLEWM_HOME/pusht/lewm_object.ckpt
|
||||||
|
```
|
||||||
|
|
||||||
|
例如本机使用:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export STABLEWM_HOME=/mnt/ASC1637/stablewm
|
||||||
|
```
|
||||||
|
|
||||||
|
如果没有正确设置,运行时会报找不到 `pusht_expert_train.h5`。
|
||||||
|
|
||||||
|
## 运行评估
|
||||||
|
|
||||||
|
默认 PushT 评估,保留 `torch.compile`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export STABLEWM_HOME=/path/to/stablewm
|
||||||
|
python eval.py --config-name=pusht.yaml policy=pusht/lewm
|
||||||
|
```
|
||||||
|
|
||||||
|
快速 smoke test:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export STABLEWM_HOME=/path/to/stablewm
|
||||||
|
python eval.py --config-name=pusht.yaml policy=pusht/lewm \
|
||||||
|
eval.num_eval=1 \
|
||||||
|
world.num_envs=1 \
|
||||||
|
output.filename=/tmp/lewm_smoke_test.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
smoke test 应该能正常结束,并打印类似:
|
||||||
|
|
||||||
|
```text
|
||||||
|
{'success_rate': 100.0, ...}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 常见问题
|
||||||
|
|
||||||
|
### `HSA_STATUS_ERROR_INVALID_PACKET_FORMAT`
|
||||||
|
|
||||||
|
如果开启 `torch.compile` 时出现这个错误,先检查 torch 版本:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -c "import torch; print(torch.__version__, torch.version.hip)"
|
||||||
|
```
|
||||||
|
|
||||||
|
如果是 `2.12.0+rocm7.1`,建议降级到本项目验证通过的组合:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m pip install --force-reinstall \
|
||||||
|
--index-url https://download.pytorch.org/whl/rocm7.1 \
|
||||||
|
--extra-index-url https://pypi.tuna.tsinghua.edu.cn/simple \
|
||||||
|
"torch==2.10.0" \
|
||||||
|
"torchvision==0.25.0"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 找不到 `pusht_expert_train.h5`
|
||||||
|
|
||||||
|
设置 `STABLEWM_HOME` 到包含数据和 checkpoint 的目录:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export STABLEWM_HOME=/path/to/stablewm
|
||||||
|
```
|
||||||
|
|
||||||
|
### pip 尝试构建旧版 `gym==0.21`
|
||||||
|
|
||||||
|
这是依赖解析回退导致的。先显式安装兼容版本:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m pip install \
|
||||||
|
--index-url https://pypi.tuna.tsinghua.edu.cn/simple \
|
||||||
|
"gymnasium[all]==1.2.2" \
|
||||||
|
"stable-baselines3==2.8.0"
|
||||||
|
```
|
||||||
|
|
||||||
|
### uv 或 pip 访问海外源很慢
|
||||||
|
|
||||||
|
普通 Python 包使用国内 PyPI 镜像:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
--index-url https://pypi.tuna.tsinghua.edu.cn/simple
|
||||||
|
```
|
||||||
|
|
||||||
|
PyTorch ROCm wheel 继续使用 PyTorch 官方 ROCm 源:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
--index-url https://download.pytorch.org/whl/rocm7.1
|
||||||
|
```
|
||||||
27
README.md
27
README.md
@@ -84,6 +84,33 @@ python eval.py --config-name=pusht.yaml policy=pusht/lewm
|
|||||||
python eval.py --config-name=pusht.yaml policy=pusht/lewm_object.ckpt
|
python eval.py --config-name=pusht.yaml policy=pusht/lewm_object.ckpt
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Profiling
|
||||||
|
|
||||||
|
`eval.py` now supports optional inference profiling with PyTorch's native profiler.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python eval.py --config-name=pusht.yaml policy=pusht/lewm \
|
||||||
|
inference_precision=bf16 \
|
||||||
|
+profile.enabled=true \
|
||||||
|
+profile.with_stack=true \
|
||||||
|
+profile.record_shapes=true \
|
||||||
|
+profile.profile_memory=true
|
||||||
|
```
|
||||||
|
|
||||||
|
Supported inference precision modes:
|
||||||
|
- `inference_precision=fp32`
|
||||||
|
- `inference_precision=bf16`
|
||||||
|
- `inference_precision=fp16`
|
||||||
|
|
||||||
|
Outputs are written next to the evaluation results:
|
||||||
|
- `torch_profile/key_averages.txt` for the aggregated operator table
|
||||||
|
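- `torch_profile/trace.json` for Chrome tracing (written only when `profile.export_tensorboard=false`; with the default `export_tensorboard=true` the TensorBoard trace files below are produced instead)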
|
||||||
|
- TensorBoard trace files under `torch_profile/`
|
||||||
|
|
||||||
|
The trace includes custom regions such as `eval.world_evaluate_from_dataset`, `lewm.get_cost`, `lewm.rollout`, and `lewm.predict` to make the planning path easier to inspect.
|
||||||
|
|
||||||
## Pretrained Checkpoints
|
## Pretrained Checkpoints
|
||||||
|
|
||||||
Pre-trained checkpoints are available on [Google Drive](https://drive.google.com/drive/folders/1r31os0d4-rR0mdHc7OlY_e5nh3XT4r4e). Download the checkpoint archive and place the extracted files under `$STABLEWM_HOME/`.
|
Pre-trained checkpoints are available on [Google Drive](https://drive.google.com/drive/folders/1r31os0d4-rR0mdHc7OlY_e5nh3XT4r4e). Download the checkpoint archive and place the extracted files under `$STABLEWM_HOME/`.
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ dataset:
|
|||||||
|
|
||||||
seed: 42
|
seed: 42
|
||||||
policy: random # ckpt name or random
|
policy: random # ckpt name or random
|
||||||
|
inference_precision: fp16
|
||||||
|
|
||||||
plan_config:
|
plan_config:
|
||||||
horizon: 5
|
horizon: 5
|
||||||
@@ -36,6 +37,10 @@ eval:
|
|||||||
goal_offset_steps: 25
|
goal_offset_steps: 25
|
||||||
eval_budget: 50
|
eval_budget: 50
|
||||||
img_size: 224
|
img_size: 224
|
||||||
|
save_video: false
|
||||||
|
compile_warmup:
|
||||||
|
enabled: true
|
||||||
|
num_eval: 1
|
||||||
dataset_name: ogbench/cube_single_expert
|
dataset_name: ogbench/cube_single_expert
|
||||||
callables:
|
callables:
|
||||||
# -- set state
|
# -- set state
|
||||||
@@ -56,6 +61,21 @@ eval:
|
|||||||
target_quat:
|
target_quat:
|
||||||
value: goal_privileged_block_0_quat
|
value: goal_privileged_block_0_quat
|
||||||
|
|
||||||
|
multi_node:
|
||||||
|
enabled: false
|
||||||
|
backend: gloo
|
||||||
|
rank_env: RANK
|
||||||
|
world_size_env: WORLD_SIZE
|
||||||
|
local_rank_env: LOCAL_RANK
|
||||||
|
aggregate_results: true
|
||||||
|
sync_before_return: false
|
||||||
|
destroy_process_group: true
|
||||||
|
shard_strategy: round_robin
|
||||||
|
|
||||||
|
preload_wait:
|
||||||
|
enabled: false
|
||||||
|
file: /tmp/lewm_preload_start
|
||||||
|
poll_interval: 1.0
|
||||||
|
|
||||||
output:
|
output:
|
||||||
filename: ogb_cube_results.txt
|
filename: ogb_cube_results.txt
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ dataset:
|
|||||||
|
|
||||||
seed: 42
|
seed: 42
|
||||||
policy: random # ckpt name or random
|
policy: random # ckpt name or random
|
||||||
|
inference_precision: fp16
|
||||||
|
|
||||||
plan_config:
|
plan_config:
|
||||||
horizon: 5
|
horizon: 5
|
||||||
@@ -31,6 +32,10 @@ eval:
|
|||||||
goal_offset_steps: 25
|
goal_offset_steps: 25
|
||||||
eval_budget: 50
|
eval_budget: 50
|
||||||
img_size: 224
|
img_size: 224
|
||||||
|
save_video: false
|
||||||
|
compile_warmup:
|
||||||
|
enabled: true
|
||||||
|
num_eval: 1
|
||||||
dataset_name: pusht_expert_train
|
dataset_name: pusht_expert_train
|
||||||
callables:
|
callables:
|
||||||
# -- set state
|
# -- set state
|
||||||
@@ -43,6 +48,22 @@ eval:
|
|||||||
args:
|
args:
|
||||||
goal_state:
|
goal_state:
|
||||||
value: goal_state
|
value: goal_state
|
||||||
|
|
||||||
|
multi_node:
|
||||||
|
enabled: false
|
||||||
|
backend: gloo
|
||||||
|
rank_env: RANK
|
||||||
|
world_size_env: WORLD_SIZE
|
||||||
|
local_rank_env: LOCAL_RANK
|
||||||
|
aggregate_results: true
|
||||||
|
sync_before_return: false
|
||||||
|
destroy_process_group: true
|
||||||
|
shard_strategy: round_robin
|
||||||
|
|
||||||
|
preload_wait:
|
||||||
|
enabled: false
|
||||||
|
file: /tmp/lewm_preload_start
|
||||||
|
poll_interval: 1.0
|
||||||
|
|
||||||
output:
|
output:
|
||||||
filename: pusht_results.txt
|
filename: pusht_results.txt
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ dataset:
|
|||||||
|
|
||||||
seed: 42
|
seed: 42
|
||||||
policy: random # ckpt name or random
|
policy: random # ckpt name or random
|
||||||
|
inference_precision: fp16
|
||||||
|
|
||||||
plan_config:
|
plan_config:
|
||||||
horizon: 5
|
horizon: 5
|
||||||
@@ -30,6 +31,10 @@ eval:
|
|||||||
goal_offset_steps: 25
|
goal_offset_steps: 25
|
||||||
eval_budget: 50
|
eval_budget: 50
|
||||||
img_size: 224
|
img_size: 224
|
||||||
|
save_video: false
|
||||||
|
compile_warmup:
|
||||||
|
enabled: true
|
||||||
|
num_eval: 1
|
||||||
dataset_name: dmc/reacher_random
|
dataset_name: dmc/reacher_random
|
||||||
callables:
|
callables:
|
||||||
# -- set state
|
# -- set state
|
||||||
@@ -45,6 +50,21 @@ eval:
|
|||||||
target_qpos:
|
target_qpos:
|
||||||
value: goal_qpos
|
value: goal_qpos
|
||||||
|
|
||||||
|
multi_node:
|
||||||
|
enabled: false
|
||||||
|
backend: gloo
|
||||||
|
rank_env: RANK
|
||||||
|
world_size_env: WORLD_SIZE
|
||||||
|
local_rank_env: LOCAL_RANK
|
||||||
|
aggregate_results: true
|
||||||
|
sync_before_return: false
|
||||||
|
destroy_process_group: true
|
||||||
|
shard_strategy: round_robin
|
||||||
|
|
||||||
|
preload_wait:
|
||||||
|
enabled: false
|
||||||
|
file: /tmp/lewm_preload_start
|
||||||
|
poll_interval: 1.0
|
||||||
|
|
||||||
output:
|
output:
|
||||||
filename: dmc_results.txt
|
filename: dmc_results.txt
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
_target_: stable_worldmodel.solver.CEMSolver
|
_target_: stable_worldmodel.solver.CEMSolver
|
||||||
model: ???
|
model: ???
|
||||||
batch_size: 1
|
batch_size: 16
|
||||||
|
# Original defaults: num_samples=300, n_steps=30, topk=30, batch_size=8.
|
||||||
num_samples: 300
|
num_samples: 300
|
||||||
var_scale: 1.0
|
var_scale: 1.0
|
||||||
n_steps: 30
|
n_steps: 30
|
||||||
topk: 30
|
topk: 8
|
||||||
device: "cuda"
|
device: "cuda"
|
||||||
seed: ${seed}
|
seed: ${seed}
|
||||||
|
|||||||
14
config/eval/solver/gradient.yaml
Normal file
14
config/eval/solver/gradient.yaml
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
_target_: stable_worldmodel.solver.GradientSolver
|
||||||
|
model: ???
|
||||||
|
# Original adam.yaml reference: n_steps=30, num_samples=100, batch_size=1, lr=0.1.
|
||||||
|
n_steps: 90
|
||||||
|
batch_size: 100
|
||||||
|
num_samples: 1
|
||||||
|
action_noise: 0
|
||||||
|
device: "cuda"
|
||||||
|
seed: ${seed}
|
||||||
|
optimizer_cls:
|
||||||
|
_target_: hydra.utils.get_class
|
||||||
|
path: torch.optim.AdamW
|
||||||
|
optimizer_kwargs:
|
||||||
|
lr: 0.075
|
||||||
@@ -12,6 +12,7 @@ world:
|
|||||||
|
|
||||||
seed: 42
|
seed: 42
|
||||||
policy: random # ckpt name or random
|
policy: random # ckpt name or random
|
||||||
|
inference_precision: fp16
|
||||||
|
|
||||||
dataset:
|
dataset:
|
||||||
stats: ${eval.dataset_name}
|
stats: ${eval.dataset_name}
|
||||||
@@ -30,6 +31,10 @@ eval:
|
|||||||
goal_offset_steps: 25
|
goal_offset_steps: 25
|
||||||
eval_budget: 50
|
eval_budget: 50
|
||||||
img_size: 224
|
img_size: 224
|
||||||
|
save_video: false
|
||||||
|
compile_warmup:
|
||||||
|
enabled: true
|
||||||
|
num_eval: 1
|
||||||
dataset_name: tworoom
|
dataset_name: tworoom
|
||||||
callables:
|
callables:
|
||||||
# -- set state
|
# -- set state
|
||||||
@@ -43,5 +48,21 @@ eval:
|
|||||||
goal_state:
|
goal_state:
|
||||||
value: goal_proprio
|
value: goal_proprio
|
||||||
|
|
||||||
|
multi_node:
|
||||||
|
enabled: false
|
||||||
|
backend: gloo
|
||||||
|
rank_env: RANK
|
||||||
|
world_size_env: WORLD_SIZE
|
||||||
|
local_rank_env: LOCAL_RANK
|
||||||
|
aggregate_results: true
|
||||||
|
sync_before_return: false
|
||||||
|
destroy_process_group: true
|
||||||
|
shard_strategy: round_robin
|
||||||
|
|
||||||
|
preload_wait:
|
||||||
|
enabled: false
|
||||||
|
file: /tmp/lewm_preload_start
|
||||||
|
poll_interval: 1.0
|
||||||
|
|
||||||
output:
|
output:
|
||||||
filename: tworoom_results.txt
|
filename: tworoom_results.txt
|
||||||
|
|||||||
836
eval.py
836
eval.py
@@ -2,8 +2,12 @@ import os
|
|||||||
|
|
||||||
os.environ["MUJOCO_GL"] = "egl"
|
os.environ["MUJOCO_GL"] = "egl"
|
||||||
|
|
||||||
|
import multiprocessing as mp
|
||||||
import time
|
import time
|
||||||
|
import traceback
|
||||||
|
from contextlib import nullcontext
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import tempfile
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -46,76 +50,331 @@ def get_dataset(cfg, dataset_name):
|
|||||||
)
|
)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path="./config/eval", config_name="pusht")
|
|
||||||
def run(cfg: DictConfig):
|
|
||||||
"""Run evaluation of dinowm vs random policy."""
|
|
||||||
assert (
|
|
||||||
cfg.plan_config.horizon * cfg.plan_config.action_block <= cfg.eval.eval_budget
|
|
||||||
), "Planning horizon must be smaller than or equal to eval_budget"
|
|
||||||
|
|
||||||
# create world environment
|
def get_profile_cfg(cfg):
|
||||||
cfg.world.max_episode_steps = 2 * cfg.eval.eval_budget
|
profile_cfg = {
|
||||||
world = swm.World(**cfg.world, image_shape=(224, 224))
|
"enabled": False,
|
||||||
|
"trace_dirname": "torch_profile",
|
||||||
|
"record_shapes": True,
|
||||||
|
"profile_memory": True,
|
||||||
|
"with_stack": False,
|
||||||
|
"with_flops": False,
|
||||||
|
"row_limit": 40,
|
||||||
|
"worker_name": "eval",
|
||||||
|
"export_chrome_trace": True,
|
||||||
|
"export_tensorboard": True,
|
||||||
|
}
|
||||||
|
cfg_profile = cfg.get("profile")
|
||||||
|
if cfg_profile is not None:
|
||||||
|
profile_cfg.update(OmegaConf.to_container(cfg_profile, resolve=True))
|
||||||
|
return profile_cfg
|
||||||
|
|
||||||
# create the transform
|
|
||||||
transform = {
|
def get_compile_cfg(cfg):
|
||||||
"pixels": img_transform(cfg),
|
compile_cfg = {
|
||||||
"goal": img_transform(cfg),
|
"enabled": True,
|
||||||
|
"target": "predictor",
|
||||||
|
"mode": "reduce-overhead",
|
||||||
|
"fullgraph": False,
|
||||||
|
"dynamic": False,
|
||||||
|
"cuda_only": True,
|
||||||
|
}
|
||||||
|
cfg_compile = cfg.get("compile")
|
||||||
|
if cfg_compile is not None:
|
||||||
|
compile_cfg.update(OmegaConf.to_container(cfg_compile, resolve=True))
|
||||||
|
return compile_cfg
|
||||||
|
|
||||||
|
|
||||||
|
def get_compile_warmup_cfg(cfg):
    """Return the compile-warmup settings as a plain dict.

    Starts from the defaults (``enabled=True``, ``num_eval=1``) and overlays
    the ``compile_warmup`` node of ``cfg``. The node is looked up at the
    config root first and, when it is absent there, under ``cfg["eval"]``.
    """
    warmup_cfg = {
        "enabled": True,
        "num_eval": 1,
    }
    cfg_warmup = cfg.get("compile_warmup")
    if cfg_warmup is None:
        # Fall back to a ``compile_warmup`` node nested under ``eval``.
        cfg_eval = cfg.get("eval")
        if cfg_eval is not None:
            cfg_warmup = cfg_eval.get("compile_warmup")
    if cfg_warmup is not None:
        warmup_cfg.update(OmegaConf.to_container(cfg_warmup, resolve=True))
    return warmup_cfg
|
||||||
|
|
||||||
|
|
||||||
|
def get_preload_wait_cfg(cfg):
    """Return the preload-wait settings as a plain dict.

    Defaults (``enabled=False``, ``file="/tmp/lewm_preload_start"``,
    ``poll_interval=1.0``) are overridden by the optional ``preload_wait``
    node found at the root of ``cfg``.
    """
    preload_cfg = {
        "enabled": False,
        "file": "/tmp/lewm_preload_start",
        "poll_interval": 1.0,
    }
    cfg_preload = cfg.get("preload_wait")
    if cfg_preload is not None:
        preload_cfg.update(OmegaConf.to_container(cfg_preload, resolve=True))
    return preload_cfg
|
||||||
|
|
||||||
|
|
||||||
|
def wait_for_preload_signal(cfg, rank=0):
    """Block until the preload start-signal file exists.

    Does nothing when ``preload_wait.enabled`` is false. Otherwise rank 0
    prints the path of the signal file and polls for it every
    ``poll_interval`` seconds. When a ``torch.distributed`` process group is
    initialised, a barrier is issued before and after the wait.

    NOTE(review): the indentation of this function was reconstructed; the
    polling loop is assumed to run on rank 0 only, with the other ranks
    held back by the trailing barrier -- confirm against the real file.
    """
    preload_cfg = get_preload_wait_cfg(cfg)
    if not preload_cfg["enabled"]:
        return

    dist_ready = (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
    )
    if dist_ready:
        torch.distributed.barrier()

    signal_path = Path(str(preload_cfg["file"])).expanduser()
    poll_interval = float(preload_cfg["poll_interval"])
    if rank == 0:
        print(
            "Preload ready. Create this file to start evaluation: "
            f"{signal_path}",
            flush=True,
        )
        while not signal_path.exists():
            time.sleep(poll_interval)
        print("Preload start signal received. Starting evaluation.", flush=True)

    if dist_ready:
        torch.distributed.barrier()
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_compile_inference_target(model, cfg, device):
|
||||||
|
compile_cfg = get_compile_cfg(cfg)
|
||||||
|
compile_target = "disabled"
|
||||||
|
|
||||||
|
if not compile_cfg["enabled"]:
|
||||||
|
return model, compile_cfg, compile_target
|
||||||
|
|
||||||
|
if not hasattr(torch, "compile"):
|
||||||
|
print("torch.compile is unavailable, skipping inference compilation.")
|
||||||
|
return model, compile_cfg, compile_target
|
||||||
|
|
||||||
|
if compile_cfg["cuda_only"] and not str(device).startswith("cuda"):
|
||||||
|
print("Skipping torch.compile because compile.cuda_only=true and device is not CUDA.")
|
||||||
|
return model, compile_cfg, compile_target
|
||||||
|
|
||||||
|
target = str(compile_cfg["target"]).lower()
|
||||||
|
compile_kwargs = {
|
||||||
|
"mode": compile_cfg["mode"],
|
||||||
|
"fullgraph": compile_cfg["fullgraph"],
|
||||||
|
"dynamic": compile_cfg["dynamic"],
|
||||||
}
|
}
|
||||||
|
|
||||||
dataset = get_dataset(cfg, cfg.eval.dataset_name)
|
if target == "predictor":
|
||||||
stats_dataset = dataset # get_dataset(cfg, cfg.dataset.stats)
|
if not hasattr(model, "predictor"):
|
||||||
col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx"
|
print("Requested compile target 'predictor' is unavailable on the model.")
|
||||||
ep_indices, _ = np.unique(stats_dataset.get_col_data(col_name), return_index=True)
|
return model, compile_cfg, compile_target
|
||||||
|
model.predictor = torch.compile(model.predictor, **compile_kwargs)
|
||||||
|
compile_target = "predictor"
|
||||||
|
elif target == "predict":
|
||||||
|
if not hasattr(model, "predict"):
|
||||||
|
print("Requested compile target 'predict' is unavailable on the model.")
|
||||||
|
return model, compile_cfg, compile_target
|
||||||
|
model.predict = torch.compile(model.predict, **compile_kwargs)
|
||||||
|
compile_target = "predict"
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Unsupported compile.target={target}. Expected one of: predictor, predict."
|
||||||
|
)
|
||||||
|
|
||||||
|
return model, compile_cfg, compile_target
|
||||||
|
|
||||||
|
|
||||||
|
def get_inference_context(cfg, device):
    """Build the autocast context for the requested inference precision.

    Args:
        cfg: Mapping-like config; ``inference_precision`` is read from it
            (case-insensitive, defaults to ``"fp32"``).
        device: Device the model runs on, as a string or ``torch.device``.
            Anything whose string form starts with ``"cuda"`` is treated as
            CUDA, everything else as CPU.

    Returns:
        A ``(context_manager, precision_name)`` tuple where
        ``precision_name`` is one of ``"fp32"``, ``"bf16"`` or ``"fp16"``.

    Raises:
        ValueError: If the precision string is not recognised.
    """
    precision = str(cfg.get("inference_precision", "fp32")).lower()
    # str() so that torch.device objects are accepted as well as strings.
    device_type = "cuda" if str(device).startswith("cuda") else "cpu"

    if precision == "fp32":
        return nullcontext(), "fp32"

    if precision in {"bf16", "bfloat16"}:
        return (
            torch.autocast(device_type=device_type, dtype=torch.bfloat16),
            "bf16",
        )

    if precision in {"fp16", "float16"}:
        if device_type != "cuda":
            print("fp16 inference is only supported on CUDA, falling back to fp32.")
            return nullcontext(), "fp32"
        return (
            torch.autocast(device_type=device_type, dtype=torch.float16),
            "fp16",
        )

    raise ValueError(
        f"Unsupported inference_precision={precision}. Expected one of: fp32, bf16, fp16."
    )
|
||||||
|
|
||||||
|
|
||||||
|
def get_eval_grad_context(solver=None):
    """Return the autograd context to evaluate under.

    A ``swm.solver.GradientSolver`` gets ``torch.enable_grad()``; any other
    solver (including ``None``) gets ``torch.inference_mode()``.
    """
    if isinstance(solver, swm.solver.GradientSolver):
        return torch.enable_grad()
    return torch.inference_mode()
|
||||||
|
|
||||||
|
|
||||||
|
def make_profiler(cfg, results_path):
    """Create the (optional) PyTorch profiler for an evaluation run.

    Args:
        cfg: Config passed to ``get_profile_cfg``.
        results_path: Directory (``Path`` or string) under which the trace
            directory ``profile.trace_dirname`` is created.

    Returns:
        ``(profiler, profile_dir, profile_cfg)``. When profiling is disabled
        the profiler is a ``nullcontext()`` and ``profile_dir`` is ``None``.
    """
    profile_cfg = get_profile_cfg(cfg)
    if not profile_cfg["enabled"]:
        return nullcontext(), None, profile_cfg

    # CPU activity is always recorded; CUDA only when a device is visible.
    activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(torch.profiler.ProfilerActivity.CUDA)

    profile_dir = Path(results_path) / profile_cfg["trace_dirname"]
    profile_dir.mkdir(parents=True, exist_ok=True)

    profiler = torch.profiler.profile(
        activities=activities,
        record_shapes=profile_cfg["record_shapes"],
        profile_memory=profile_cfg["profile_memory"],
        with_stack=profile_cfg["with_stack"],
        with_flops=profile_cfg["with_flops"],
    )
    return profiler, profile_dir, profile_cfg
|
||||||
|
|
||||||
|
|
||||||
|
def dump_profiler_results(profiler, profile_dir, profile_cfg):
    """Write the profiler summary table and one trace export to disk.

    Args:
        profiler: A finished ``torch.profiler.profile`` instance, or ``None``.
        profile_dir: Directory to write into, or ``None``.
        profile_cfg: Dict providing ``row_limit``, ``export_tensorboard``,
            ``worker_name`` and ``export_chrome_trace``.

    Returns:
        Path of the written ``key_averages.txt``, or ``None`` when either
        ``profiler`` or ``profile_dir`` is ``None``.
    """
    if profiler is None or profile_dir is None:
        return None

    has_cuda = torch.cuda.is_available()
    table = profiler.key_averages().table(
        sort_by="self_cuda_time_total" if has_cuda else "self_cpu_time_total",
        row_limit=profile_cfg["row_limit"],
    )

    summary_path = profile_dir / "key_averages.txt"
    summary_path.write_text(table)

    # Only one trace export runs: the TensorBoard handler takes precedence,
    # and trace.json is written only when export_tensorboard is false.
    if profile_cfg["export_tensorboard"]:
        trace_handler = torch.profiler.tensorboard_trace_handler(
            str(profile_dir), worker_name=profile_cfg["worker_name"]
        )
        trace_handler(profiler)
    elif profile_cfg["export_chrome_trace"]:
        profiler.export_chrome_trace(str(profile_dir / "trace.json"))

    return summary_path
|
||||||
|
|
||||||
|
|
||||||
|
def get_multi_gpu_cfg(cfg):
    """Return the multi-GPU settings as a plain dict.

    Defaults (``enabled=False``, ``devices=None``, ``start_method="spawn"``)
    are overridden by the optional ``multi_gpu`` node at the root of ``cfg``.
    """
    multi_gpu_cfg = {
        "enabled": False,
        "devices": None,
        "start_method": "spawn",
    }
    cfg_multi_gpu = cfg.get("multi_gpu")
    if cfg_multi_gpu is not None:
        multi_gpu_cfg.update(OmegaConf.to_container(cfg_multi_gpu, resolve=True))
    return multi_gpu_cfg
|
||||||
|
|
||||||
|
|
||||||
|
def get_multi_node_cfg(cfg):
    """Return the multi-node settings as a plain dict.

    The defaults listed below are overridden key by key with the optional
    ``multi_node`` node at the root of ``cfg``.
    """
    multi_node_cfg = {
        "enabled": False,
        "backend": "gloo",
        # Names of the environment variables to read rank info from.
        "rank_env": "RANK",
        "world_size_env": "WORLD_SIZE",
        "local_rank_env": "LOCAL_RANK",
        "output_mode": "single",
        "aggregate_results": True,
        "sync_before_return": False,
        "destroy_process_group": True,
        "shard_strategy": "round_robin",
    }
    cfg_multi_node = cfg.get("multi_node")
    if cfg_multi_node is not None:
        multi_node_cfg.update(OmegaConf.to_container(cfg_multi_node, resolve=True))
    return multi_node_cfg
|
||||||
|
|
||||||
|
|
||||||
|
def get_dist_env(name, default=None):
    """Read environment variable ``name`` as an integer.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is unset.

    Returns:
        The value converted with ``int()``, or ``None`` when the variable is
        unset and ``default`` is ``None``.

    Raises:
        ValueError: If the value cannot be parsed as an integer; the message
            names the offending variable.
    """
    value = os.environ.get(name, default)
    if value is None:
        return None
    try:
        return int(value)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {name} must be an integer, got {value!r}"
        ) from err
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank_context(cfg):
    """Return ``(rank, world_size, local_rank)`` for this process.

    When ``multi_node.enabled`` is false the single-process triple
    ``(0, 1, 0)`` is returned. Otherwise the values are read from the
    environment variables named by ``rank_env``, ``world_size_env`` and
    ``local_rank_env``; the local rank falls back to 0 when unset.

    Raises:
        ValueError: If rank or world size is missing, or out of range.
    """
    multi_node_cfg = get_multi_node_cfg(cfg)
    if not multi_node_cfg["enabled"]:
        return 0, 1, 0

    rank = get_dist_env(multi_node_cfg["rank_env"])
    world_size = get_dist_env(multi_node_cfg["world_size_env"])
    local_rank = get_dist_env(multi_node_cfg["local_rank_env"], 0)

    if rank is None or world_size is None:
        raise ValueError(
            "multi_node.enabled=true requires torchrun env vars RANK and WORLD_SIZE"
        )
    if world_size < 1:
        raise ValueError("WORLD_SIZE must be >= 1")
    if rank < 0 or rank >= world_size:
        raise ValueError("RANK must be in [0, WORLD_SIZE)")
    return rank, world_size, local_rank
|
||||||
|
|
||||||
|
|
||||||
|
def all_gather_eval_result(result):
    """Gather ``result`` from every rank of the default process group.

    Returns a list with one entry per rank, filled in by
    ``torch.distributed.all_gather_object``.
    """
    world_size = torch.distributed.get_world_size()
    # Placeholder slots that all_gather_object overwrites in place.
    payload = [None] * world_size
    torch.distributed.all_gather_object(payload, result)
    return payload
|
||||||
|
|
||||||
|
|
||||||
|
def finalize_multi_node_process_group(cfg):
    """Destroy the default process group if the config asks for it.

    Nothing happens when ``multi_node.destroy_process_group`` is false or
    when no ``torch.distributed`` process group is initialised.
    """
    multi_node_cfg = get_multi_node_cfg(cfg)
    if not multi_node_cfg["destroy_process_group"]:
        return
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank_result_path(output_dir: Path, cfg: DictConfig, rank: int) -> Path:
    """Return the results file path for ``rank`` inside ``output_dir``.

    Rank 0 uses ``cfg.output.filename`` unchanged. Other ranks get
    ``.rank<N>`` inserted before the file suffix (``res.txt`` becomes
    ``res.rank2.txt``), or appended when the filename has no suffix.
    """
    filename = str(cfg.output.filename)
    if rank == 0:
        return output_dir / filename

    suffix = Path(filename).suffix
    stem = Path(filename).stem
    if suffix:
        ranked_filename = f"{stem}.rank{rank}{suffix}"
    else:
        ranked_filename = f"{filename}.rank{rank}"
    return output_dir / ranked_filename
|
||||||
|
|
||||||
|
|
||||||
|
def build_process(cfg, dataset):
|
||||||
process = {}
|
process = {}
|
||||||
for col in cfg.dataset.keys_to_cache:
|
for col in cfg.dataset.keys_to_cache:
|
||||||
if col in ["pixels"]:
|
if col in ["pixels"]:
|
||||||
continue
|
continue
|
||||||
processor = preprocessing.StandardScaler()
|
processor = preprocessing.StandardScaler()
|
||||||
col_data = stats_dataset.get_col_data(col)
|
col_data = dataset.get_col_data(col)
|
||||||
col_data = col_data[~np.isnan(col_data).any(axis=1)]
|
col_data = col_data[~np.isnan(col_data).any(axis=1)]
|
||||||
processor.fit(col_data)
|
processor.fit(col_data)
|
||||||
process[col] = processor
|
process[col] = processor
|
||||||
|
|
||||||
if col != "action":
|
if col != "action":
|
||||||
process[f"goal_{col}"] = process[col]
|
process[f"goal_{col}"] = process[col]
|
||||||
|
return process
|
||||||
|
|
||||||
# -- run evaluation
|
|
||||||
policy = cfg.get("policy", "random")
|
|
||||||
|
|
||||||
if policy != "random":
|
def sample_eval_cases(cfg, dataset):
|
||||||
model = swm.policy.AutoCostModel(cfg.policy)
|
stats_dataset = dataset
|
||||||
model = model.to("cuda")
|
col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx"
|
||||||
model = model.eval()
|
ep_indices, _ = np.unique(stats_dataset.get_col_data(col_name), return_index=True)
|
||||||
model.requires_grad_(False)
|
|
||||||
model.interpolate_pos_encoding = True
|
|
||||||
config = swm.PlanConfig(**cfg.plan_config)
|
|
||||||
solver = hydra.utils.instantiate(cfg.solver, model=model)
|
|
||||||
policy = swm.policy.WorldModelPolicy(
|
|
||||||
solver=solver, config=config, process=process, transform=transform
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
policy = swm.policy.RandomPolicy()
|
|
||||||
|
|
||||||
results_path = (
|
|
||||||
Path(swm.data.utils.get_cache_dir(), cfg.policy).parent
|
|
||||||
if cfg.policy != "random"
|
|
||||||
else Path(__file__).parent
|
|
||||||
)
|
|
||||||
|
|
||||||
# sample the episodes and the starting indices
|
|
||||||
episode_len = get_episodes_length(dataset, ep_indices)
|
episode_len = get_episodes_length(dataset, ep_indices)
|
||||||
max_start_idx = episode_len - cfg.eval.goal_offset_steps - 1
|
max_start_idx = episode_len - cfg.eval.goal_offset_steps - 1
|
||||||
max_start_idx_dict = {ep_id: max_start_idx[i] for i, ep_id in enumerate(ep_indices)}
|
max_start_idx_dict = {ep_id: max_start_idx[i] for i, ep_id in enumerate(ep_indices)}
|
||||||
# Map each dataset row’s episode_idx to its max_start_idx
|
|
||||||
col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx"
|
|
||||||
max_start_per_row = np.array(
|
max_start_per_row = np.array(
|
||||||
[max_start_idx_dict[ep_id] for ep_id in dataset.get_col_data(col_name)]
|
[max_start_idx_dict[ep_id] for ep_id in dataset.get_col_data(col_name)]
|
||||||
)
|
)
|
||||||
|
|
||||||
# remove all the lines of dataset for which dataset['step_idx'] > max_start_per_row
|
|
||||||
valid_mask = dataset.get_col_data("step_idx") <= max_start_per_row
|
valid_mask = dataset.get_col_data("step_idx") <= max_start_per_row
|
||||||
valid_indices = np.nonzero(valid_mask)[0]
|
valid_indices = np.nonzero(valid_mask)[0]
|
||||||
print(valid_mask.sum(), "valid starting points found for evaluation.")
|
print(valid_mask.sum(), "valid starting points found for evaluation.")
|
||||||
@@ -124,35 +383,478 @@ def run(cfg: DictConfig):
|
|||||||
random_episode_indices = g.choice(
|
random_episode_indices = g.choice(
|
||||||
len(valid_indices) - 1, size=cfg.eval.num_eval, replace=False
|
len(valid_indices) - 1, size=cfg.eval.num_eval, replace=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# sort increasingly to avoid issues with HDF5Dataset indexing
|
|
||||||
random_episode_indices = np.sort(valid_indices[random_episode_indices])
|
random_episode_indices = np.sort(valid_indices[random_episode_indices])
|
||||||
|
|
||||||
print(random_episode_indices)
|
print(random_episode_indices)
|
||||||
|
|
||||||
eval_episodes = dataset.get_row_data(random_episode_indices)[col_name]
|
rows = dataset.get_row_data(random_episode_indices)
|
||||||
eval_start_idx = dataset.get_row_data(random_episode_indices)["step_idx"]
|
eval_episodes = rows[col_name]
|
||||||
|
eval_start_idx = rows["step_idx"]
|
||||||
|
|
||||||
if len(eval_episodes) < cfg.eval.num_eval:
|
if len(eval_episodes) < cfg.eval.num_eval:
|
||||||
raise ValueError("Not enough episodes with sufficient length for evaluation.")
|
raise ValueError("Not enough episodes with sufficient length for evaluation.")
|
||||||
|
|
||||||
|
return eval_episodes, eval_start_idx
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_multi_gpu_devices(devices):
    """Normalise a device specification into a list of device strings.

    Args:
        devices: ``None`` (use every visible CUDA device), a single int or
            string, a comma-separated string such as ``"0,1"``, or an
            iterable of ints / strings.

    Returns:
        A list of strings. Ints and all-digit strings become ``"cuda:<n>"``;
        any other entry is passed through ``str()`` unchanged.
    """
    if devices is None:
        return [f"cuda:{idx}" for idx in range(torch.cuda.device_count())]

    # A bare scalar would otherwise fail to iterate (int) or be iterated
    # character by character (str), so wrap it into a list first.
    if isinstance(devices, int):
        devices = [devices]
    elif isinstance(devices, str):
        devices = [part.strip() for part in devices.split(",") if part.strip()]

    normalized = []
    for device in devices:
        if isinstance(device, int):
            normalized.append(f"cuda:{device}")
        elif isinstance(device, str) and device.isdigit():
            normalized.append(f"cuda:{int(device)}")
        else:
            normalized.append(str(device))
    return normalized
|
||||||
|
|
||||||
|
|
||||||
|
def shard_eval_cases(eval_episodes, eval_start_idx, num_shards):
    """Split the evaluation cases into contiguous, near-equal shards.

    Args:
        eval_episodes: Sliceable sequence of episode ids.
        eval_start_idx: Sliceable sequence of start indices, parallel to
            ``eval_episodes``.
        num_shards: Maximum number of shards to produce.

    Returns:
        A list of ``(episodes, start_indices)`` tuples. The first
        ``total % num_shards`` shards hold one extra case; empty shards are
        dropped, so fewer than ``num_shards`` entries may be returned.

    Raises:
        ValueError: If ``num_shards`` is smaller than 1.
    """
    if num_shards < 1:
        raise ValueError("num_shards must be >= 1")

    total = len(eval_episodes)
    shard_sizes = [total // num_shards] * num_shards
    # Spread the remainder over the leading shards.
    for idx in range(total % num_shards):
        shard_sizes[idx] += 1

    shards = []
    start = 0
    for size in shard_sizes:
        end = start + size
        if size > 0:
            shards.append((eval_episodes[start:end], eval_start_idx[start:end]))
        start = end
    return shards
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank_eval_subset(
    eval_episodes,
    eval_start_idx,
    rank,
    world_size,
    *,
    strategy="contiguous",
):
    """Select the evaluation cases assigned to one rank.

    Args:
        eval_episodes: Sliceable sequence of episode ids.
        eval_start_idx: Sliceable sequence of start indices, parallel to
            ``eval_episodes``.
        rank: Index of this rank, in ``[0, world_size)``.
        world_size: Total number of ranks.
        strategy: ``"round_robin"`` takes every ``world_size``-th case
            starting at ``rank``; ``"contiguous"`` takes one consecutive
            slice, with the remainder spread over the lowest ranks.

    Returns:
        ``(episodes, start_indices)`` for this rank.

    Raises:
        ValueError: On an invalid ``world_size``, ``rank`` or ``strategy``.
    """
    if world_size < 1:
        raise ValueError("world_size must be >= 1")
    if rank < 0 or rank >= world_size:
        raise ValueError("rank must be in [0, world_size)")

    if strategy == "round_robin":
        episode_subset = eval_episodes[rank::world_size]
        start_subset = eval_start_idx[rank::world_size]
        return episode_subset, start_subset
    if strategy != "contiguous":
        raise ValueError("strategy must be one of: contiguous, round_robin")

    total = len(eval_episodes)
    shard_sizes = [total // world_size] * world_size
    # The first ``total % world_size`` ranks take one extra case each.
    for idx in range(total % world_size):
        shard_sizes[idx] += 1

    start = sum(shard_sizes[:rank])
    end = start + shard_sizes[rank]
    return eval_episodes[start:end], eval_start_idx[start:end]
|
||||||
|
|
||||||
|
|
||||||
|
def run_eval_subset(
    cfg: DictConfig,
    eval_episodes,
    eval_start_idx,
    output_dir: Path,
    *,
    device_override: str | None = None,
    enable_profile: bool = True,
    enable_compile_warmup: bool = False,
    before_evaluate=None,
):
    """Evaluate one subset of dataset cases and return metrics plus run metadata.

    Works on a private copy of ``cfg``; the caller's config is not modified.

    Args:
        cfg: Evaluation config.
        eval_episodes: Episode identifiers to evaluate; one environment is
            created per entry.
        eval_start_idx: Start steps, parallel to ``eval_episodes``.
        output_dir: Passed to ``make_profiler`` and used as ``video_path``.
        device_override: When given, written to ``solver.device`` and used for
            model placement; a CUDA value is also made the current device.
        enable_profile: When False, ``profile.enabled`` is forced to False in
            the local config copy.
        enable_compile_warmup: When True, ``maybe_run_compile_warmup`` is
            called before the evaluation itself.
        before_evaluate: Optional zero-argument callable invoked once after
            the policy is attached and before timing starts.

    Returns:
        Dict with keys ``metrics``, ``evaluation_time``, ``inference_precision``,
        ``compile_target``, ``compile_mode``, ``profile_dir`` and
        ``profile_summary_path``.
    """
    # Round-trip through a plain container to get an independent copy.
    local_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
    local_cfg.eval.num_eval = len(eval_episodes)
    local_cfg.world.num_envs = len(eval_episodes)
    local_cfg.world.max_episode_steps = 2 * local_cfg.eval.eval_budget

    if device_override is not None:
        local_cfg.solver.device = device_override
        if torch.cuda.is_available() and str(device_override).startswith("cuda"):
            torch.cuda.set_device(torch.device(device_override))

    if not enable_profile:
        if local_cfg.get("profile") is None:
            local_cfg.profile = OmegaConf.create({"enabled": False})
        else:
            local_cfg.profile.enabled = False

    world = swm.World(**local_cfg.world, image_shape=(224, 224))
    transform = {
        "pixels": img_transform(local_cfg),
        "goal": img_transform(local_cfg),
    }
    dataset = get_dataset(local_cfg, local_cfg.eval.dataset_name)
    process = build_process(local_cfg, dataset)

    policy_name = local_cfg.get("policy", "random")
    if policy_name != "random":
        model = swm.policy.AutoCostModel(local_cfg.policy)
        device = device_override or ("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        model = model.eval()
        model.requires_grad_(False)
        model, compile_cfg, compile_target = maybe_compile_inference_target(
            model, local_cfg, device
        )
        inference_ctx, inference_precision = get_inference_context(local_cfg, device)
        model.interpolate_pos_encoding = True
        config = swm.PlanConfig(**local_cfg.plan_config)
        solver = hydra.utils.instantiate(local_cfg.solver, model=model)
        policy = swm.policy.WorldModelPolicy(
            solver=solver, config=config, process=process, transform=transform
        )
    else:
        # Random policy: no model, so fill in neutral values for everything
        # the return dict and the context managers below expect.
        policy = swm.policy.RandomPolicy()
        solver = None
        inference_ctx = nullcontext()
        inference_precision = "fp32"
        compile_cfg = get_compile_cfg(local_cfg)
        compile_target = "disabled"
        device = device_override or ("cuda" if torch.cuda.is_available() else "cpu")

    profiler_ctx, profile_dir, profile_cfg = make_profiler(local_cfg, output_dir)
    world.set_policy(policy)

    # Drain pending GPU work so setup cost is not attributed to the timed region.
    if str(device).startswith("cuda") and torch.cuda.is_available():
        torch.cuda.synchronize()

    if before_evaluate is not None:
        before_evaluate()
        if str(device).startswith("cuda") and torch.cuda.is_available():
            torch.cuda.synchronize()

    def evaluate_subset(episodes, start_indices, *, eval_cfg=local_cfg):
        # Thin wrapper so the evaluation call takes just the cases.
        return world.evaluate_from_dataset(
            dataset,
            start_steps=list(start_indices),
            goal_offset_steps=eval_cfg.eval.goal_offset_steps,
            eval_budget=eval_cfg.eval.eval_budget,
            episodes_idx=list(episodes),
            callables=OmegaConf.to_container(
                eval_cfg.eval.get("callables"), resolve=True
            ),
            save_video=bool(eval_cfg.eval.get("save_video", False)),
            video_path=output_dir,
        )

    # NOTE(review): the nesting of the statements below was reconstructed from
    # a view without indentation — confirm which context managers enclose the
    # synchronize call and the evaluation_time measurement.
    start_time = time.time()
    with get_eval_grad_context(solver):
        with profiler_ctx as profiler:
            with inference_ctx:
                if enable_compile_warmup:
                    # Uses the caller's cfg, not the local copy; note that this
                    # runs inside the timed region.
                    maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx)
                with torch.profiler.record_function("eval.world_evaluate_from_dataset"):
                    metrics = evaluate_subset(eval_episodes, eval_start_idx)
                if str(device).startswith("cuda") and torch.cuda.is_available():
                    # Wait for outstanding kernels so the wall time covers them.
                    torch.cuda.synchronize()
    evaluation_time = time.time() - start_time
    profile_summary_path = dump_profiler_results(profiler, profile_dir, profile_cfg)

    return {
        "metrics": metrics,
        "evaluation_time": evaluation_time,
        "inference_precision": inference_precision,
        "compile_target": compile_target,
        "compile_mode": compile_cfg["mode"] if compile_target != "disabled" else None,
        "profile_dir": profile_dir,
        "profile_summary_path": profile_summary_path,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx):
    """Run a small throw-away evaluation so compilation happens before timing.

    Does nothing when warm-up is disabled, when multi-GPU mode is on (that
    mode evaluates in spawned workers), or when no cases are available. In
    multi-node mode the warm-up is limited to this rank's own cases, chosen
    with the same shard strategy the real evaluation uses.

    Args:
        cfg: Evaluation config; a modified copy (video saving and profiling
            off) is used for the warm-up run.
        eval_episodes: Episode identifiers of the full evaluation set.
        eval_start_idx: Start steps, parallel to ``eval_episodes``.
    """
    warmup_cfg = get_compile_warmup_cfg(cfg)
    if not warmup_cfg["enabled"]:
        return

    if get_multi_gpu_cfg(cfg)["enabled"]:
        print("Skipping compile warmup because multi_gpu.enabled=true uses spawned workers.")
        return

    multi_node_cfg = get_multi_node_cfg(cfg)
    if multi_node_cfg["enabled"]:
        rank, world_size, local_rank = get_rank_context(cfg)
        # Warm up on the cases this rank will actually evaluate: honour the
        # configured shard strategy instead of always slicing contiguously.
        eval_episodes, eval_start_idx = get_rank_eval_subset(
            eval_episodes,
            eval_start_idx,
            rank,
            world_size,
            strategy=multi_node_cfg["shard_strategy"],
        )
        device_override = f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu"
    else:
        device_override = None

    warmup_count = min(int(warmup_cfg["num_eval"]), len(eval_episodes))
    if warmup_count < 1:
        return

    # Independent config copy so the real run's settings stay untouched.
    warmup_eval_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
    warmup_eval_cfg.eval.num_eval = warmup_count
    warmup_eval_cfg.eval.save_video = False
    if warmup_eval_cfg.get("profile") is None:
        warmup_eval_cfg.profile = OmegaConf.create({"enabled": False})
    else:
        warmup_eval_cfg.profile.enabled = False

    # Outputs of the warm-up are discarded with the temporary directory.
    with tempfile.TemporaryDirectory(prefix="lewm_compile_warmup_") as tmpdir:
        run_eval_subset(
            warmup_eval_cfg,
            list(eval_episodes[:warmup_count]),
            list(eval_start_idx[:warmup_count]),
            Path(tmpdir),
            device_override=device_override,
            enable_profile=False,
        )
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _multi_gpu_eval_worker(
    cfg_container,
    eval_episodes,
    eval_start_idx,
    output_dir,
    device,
    shard_idx,
    queue,
):
    """Process entry point: evaluate one shard and post the outcome to ``queue``.

    Posts exactly one dict. On success it carries ``ok=True`` and the
    ``result`` of ``run_eval_subset``; on any ``Exception`` it carries
    ``ok=False`` and the formatted traceback under ``error``. Both variants
    include ``shard_idx`` so the parent can restore shard order.
    """
    try:
        # The config travels as a plain container and is rebuilt here.
        shard_cfg = OmegaConf.create(cfg_container)
        outcome = run_eval_subset(
            shard_cfg,
            eval_episodes,
            eval_start_idx,
            Path(output_dir),
            device_override=device,
            enable_profile=False,
        )
        success = {"ok": True, "shard_idx": shard_idx, "result": outcome}
        queue.put(success)
    except Exception:
        # Ship the traceback text rather than the exception object.
        failure = {"ok": False, "shard_idx": shard_idx}
        failure["error"] = traceback.format_exc()
        queue.put(failure)
|
||||||
|
|
||||||
|
|
||||||
|
def run_multi_gpu_eval(cfg, eval_episodes, eval_start_idx, output_dir: Path):
    """Evaluate all cases by fanning contiguous shards out to one process per GPU.

    Args:
        cfg: Evaluation config; sent to workers as a plain container.
        eval_episodes: Episode identifiers of the full evaluation set.
        eval_start_idx: Start steps, parallel to ``eval_episodes``.
        output_dir: Directory handed to every worker.

    Returns:
        Dict shaped like the result of ``run_eval_subset``. ``metrics`` is
        aggregated over all shards in shard order, ``evaluation_time`` is the
        wall time of the whole fan-out, and the precision/compile fields are
        taken from shard 0. Profiling fields are always None.

    Raises:
        ValueError: If fewer than two devices are configured.
        RuntimeError: If any worker reports an error or exits without
            reporting; the message contains every collected failure.
    """
    # Function-scope import: the local queue object below would otherwise be
    # confused with the module, and the file-level imports are not in view.
    from queue import Empty

    poll_seconds = 1.0

    multi_gpu_cfg = get_multi_gpu_cfg(cfg)
    devices = normalize_multi_gpu_devices(multi_gpu_cfg["devices"])
    if len(devices) < 2:
        raise ValueError("multi_gpu.enabled=true requires at least 2 CUDA devices")

    shards = shard_eval_cases(eval_episodes, eval_start_idx, min(len(devices), len(eval_episodes)))
    devices = devices[: len(shards)]

    ctx = mp.get_context(multi_gpu_cfg["start_method"])
    result_queue = ctx.Queue()
    cfg_container = OmegaConf.to_container(cfg, resolve=False)
    workers = []

    start_time = time.time()
    for shard_idx, ((shard_episodes, shard_start_idx), device) in enumerate(
        zip(shards, devices, strict=True)
    ):
        worker = ctx.Process(
            target=_multi_gpu_eval_worker,
            args=(
                cfg_container,
                list(shard_episodes),
                list(shard_start_idx),
                str(output_dir),
                device,
                shard_idx,
                result_queue,
            ),
        )
        worker.start()
        workers.append(worker)

    shard_results = {}
    errors = []
    pending = set(range(len(workers)))
    while pending:
        try:
            message = result_queue.get(timeout=poll_seconds)
        except Empty:
            # A worker killed by a signal or a native crash never posts a
            # message; a plain blocking get() would wait for it forever.
            crashed = sorted(idx for idx in pending if not workers[idx].is_alive())
            if not crashed:
                continue
            try:
                # The worker may have posted just before exiting; look once
                # more before declaring its result lost.
                message = result_queue.get(timeout=poll_seconds)
            except Empty:
                for idx in crashed:
                    errors.append(
                        f"multi-GPU eval worker for shard {idx} exited with code "
                        f"{workers[idx].exitcode} without reporting a result"
                    )
                    pending.discard(idx)
                continue

        pending.discard(message["shard_idx"])
        if message["ok"]:
            shard_results[message["shard_idx"]] = message["result"]
        else:
            errors.append(message["error"])

    for worker in workers:
        worker.join()

    if errors:
        # Report every failure, not only the first one received.
        raise RuntimeError("\n\n".join(errors))

    ordered_results = [shard_results[idx] for idx in range(len(workers))]
    metrics, reference = combine_eval_results(ordered_results)
    return {
        "metrics": metrics,
        "evaluation_time": time.time() - start_time,
        "inference_precision": reference["inference_precision"],
        "compile_target": reference["compile_target"],
        "compile_mode": reference["compile_mode"],
        "profile_dir": None,
        "profile_summary_path": None,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def combine_eval_results(ordered_results):
    """Merge per-shard evaluation results into one metrics dict.

    Args:
        ordered_results: Non-empty sequence of result dicts. Each must hold
            ``result["metrics"]["episode_successes"]`` and may hold
            ``result["metrics"]["seeds"]``. Order defines the output order.

    Returns:
        ``(metrics, reference)``. ``metrics`` has ``success_rate`` (percent),
        ``episode_successes`` (concatenated bool array) and ``seeds``
        (concatenated, or None unless every shard provided seeds).
        ``reference`` is the first result, for fields shared by all shards.

    Raises:
        ValueError: If there are no results or no episodes to aggregate.
    """
    if len(ordered_results) == 0:
        raise ValueError("combine_eval_results requires at least one shard result")

    episode_successes = np.concatenate(
        [
            np.asarray(result["metrics"]["episode_successes"], dtype=np.bool_)
            for result in ordered_results
        ]
    )
    if len(episode_successes) == 0:
        # Without this the rate below fails with a bare ZeroDivisionError.
        raise ValueError("cannot compute a success rate from zero evaluation episodes")

    shard_seeds = [result["metrics"].get("seeds") for result in ordered_results]
    # Seeds are only meaningful when they line up with every episode.
    seeds = np.concatenate(shard_seeds) if all(s is not None for s in shard_seeds) else None

    metrics = {
        "success_rate": float(np.sum(episode_successes)) / len(episode_successes) * 100.0,
        "episode_successes": episode_successes,
        "seeds": seeds,
    }
    return metrics, ordered_results[0]
|
||||||
|
|
||||||
|
|
||||||
|
def run_multi_node_eval(cfg, eval_episodes, eval_start_idx, output_dir: Path):
    """Evaluate this rank's share of the cases and optionally aggregate across ranks.

    Args:
        cfg: Evaluation config; a copy with multi-node and multi-GPU disabled
            is used for the local run.
        eval_episodes: Episode identifiers of the full evaluation set.
        eval_start_idx: Start steps, parallel to ``eval_episodes``.
        output_dir: Base directory from which the per-rank result path is derived.

    Returns:
        Without aggregation: this rank's result dict, with ``output_filename``
        set to the per-rank file name. With aggregation: the combined result
        on rank 0 and None on every other rank.

    Raises:
        ValueError: If no cases are assigned to this rank.
        RuntimeError: If ``torch.distributed`` is needed but unavailable.
    """
    rank, world_size, local_rank = get_rank_context(cfg)
    multi_node_cfg = get_multi_node_cfg(cfg)
    shard_episodes, shard_start_idx = get_rank_eval_subset(
        eval_episodes,
        eval_start_idx,
        rank,
        world_size,
        strategy=multi_node_cfg["shard_strategy"],
    )
    # NOTE(review): only the ranks without work raise here; when results are
    # aggregated, the other ranks presumably reach the gather below without
    # this one — confirm the launcher tears the job down in that case.
    if len(shard_episodes) == 0:
        raise ValueError("No evaluation episodes assigned to this rank")

    # The local run must take the plain single-process path.
    local_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
    local_cfg.multi_node.enabled = False
    if local_cfg.get("multi_gpu") is None:
        local_cfg.multi_gpu = OmegaConf.create({"enabled": False})
    else:
        local_cfg.multi_gpu.enabled = False

    device = f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu"
    preload_cfg = get_preload_wait_cfg(cfg)
    if preload_cfg["enabled"]:
        # The process group is brought up before the run when preload-wait is on.
        if not torch.distributed.is_available():
            raise RuntimeError("torch.distributed is required for preload_wait")
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend=multi_node_cfg["backend"])

    rank_output_path = get_rank_result_path(output_dir, cfg, rank)
    result = run_eval_subset(
        local_cfg,
        list(shard_episodes),
        list(shard_start_idx),
        rank_output_path.parent,
        device_override=device,
        enable_profile=False,
        before_evaluate=lambda: wait_for_preload_signal(cfg, rank=rank),
    )
    if not multi_node_cfg["aggregate_results"]:
        # Each rank reports on its own under its per-rank file name.
        result["output_filename"] = rank_output_path.name
        finalize_multi_node_process_group(cfg)
        return result

    if not torch.distributed.is_available():
        raise RuntimeError("torch.distributed is required for multi-node evaluation")
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend=multi_node_cfg["backend"])

    # NOTE(review): results are combined in gather order; with the
    # round_robin strategy that presumably differs from the original case
    # order of episode_successes — confirm whether consumers rely on it.
    gathered = all_gather_eval_result(result)
    metrics, reference = combine_eval_results(gathered)
    combined = {
        "metrics": metrics,
        # The slowest rank determines the overall evaluation time.
        "evaluation_time": max(item["evaluation_time"] for item in gathered),
        "inference_precision": reference["inference_precision"],
        "compile_target": reference["compile_target"],
        "compile_mode": reference["compile_mode"],
        "profile_dir": None,
        "profile_summary_path": None,
        "output_filename": cfg.output.filename,
    }
    if multi_node_cfg["sync_before_return"]:
        torch.distributed.barrier()
    finalize_multi_node_process_group(cfg)
    # Only rank 0 hands the combined result back to the caller.
    if rank != 0:
        return None
    return combined
|
||||||
|
|
||||||
|
@hydra.main(version_base=None, config_path="./config/eval", config_name="pusht")
|
||||||
|
def run(cfg: DictConfig):
|
||||||
|
"""Run evaluation of dinowm vs random policy."""
|
||||||
|
assert (
|
||||||
|
cfg.plan_config.horizon * cfg.plan_config.action_block <= cfg.eval.eval_budget
|
||||||
|
), "Planning horizon must be smaller than or equal to eval_budget"
|
||||||
|
|
||||||
|
dataset = get_dataset(cfg, cfg.eval.dataset_name)
|
||||||
|
eval_episodes, eval_start_idx = sample_eval_cases(cfg, dataset)
|
||||||
|
output_dir = Path.cwd().resolve()
|
||||||
|
profile_cfg = get_profile_cfg(cfg)
|
||||||
|
|
||||||
|
maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx)
|
||||||
|
eval_wall_start = time.time()
|
||||||
|
|
||||||
|
if get_multi_node_cfg(cfg)["enabled"] and get_multi_gpu_cfg(cfg)["enabled"]:
|
||||||
|
raise ValueError("multi_node.enabled and multi_gpu.enabled are mutually exclusive")
|
||||||
|
|
||||||
|
if get_multi_node_cfg(cfg)["enabled"]:
|
||||||
|
eval_result = run_multi_node_eval(cfg, eval_episodes, eval_start_idx, output_dir)
|
||||||
|
if eval_result is None:
|
||||||
|
return
|
||||||
|
elif get_multi_gpu_cfg(cfg)["enabled"]:
|
||||||
|
if profile_cfg["enabled"]:
|
||||||
|
raise ValueError("Profiling is not supported together with multi_gpu.enabled=true")
|
||||||
|
eval_result = run_multi_gpu_eval(cfg, eval_episodes, eval_start_idx, output_dir)
|
||||||
|
else:
|
||||||
|
eval_result = run_eval_subset(
|
||||||
|
cfg,
|
||||||
|
eval_episodes.tolist(),
|
||||||
|
eval_start_idx.tolist(),
|
||||||
|
output_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
metrics = eval_result["metrics"]
|
||||||
|
evaluation_time = eval_result["evaluation_time"]
|
||||||
|
inference_precision = eval_result["inference_precision"]
|
||||||
|
compile_target = eval_result["compile_target"]
|
||||||
|
compile_mode = eval_result["compile_mode"]
|
||||||
|
profile_dir = eval_result["profile_dir"]
|
||||||
|
profile_summary_path = eval_result["profile_summary_path"]
|
||||||
|
output_filename = eval_result.get("output_filename", cfg.output.filename)
|
||||||
|
|
||||||
print(metrics)
|
print(metrics)
|
||||||
|
|
||||||
results_path = results_path / cfg.output.filename
|
results_path = output_dir / output_filename
|
||||||
results_path.parent.mkdir(parents=True, exist_ok=True)
|
results_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
with results_path.open("a") as f:
|
with results_path.open("a") as f:
|
||||||
@@ -164,7 +866,17 @@ def run(cfg: DictConfig):
|
|||||||
|
|
||||||
f.write("==== RESULTS ====\n")
|
f.write("==== RESULTS ====\n")
|
||||||
f.write(f"metrics: {metrics}\n")
|
f.write(f"metrics: {metrics}\n")
|
||||||
f.write(f"evaluation_time: {end_time - start_time} seconds\n")
|
f.write(f"evaluation_time: {evaluation_time} seconds\n")
|
||||||
|
f.write(f"inference_precision: {inference_precision}\n")
|
||||||
|
f.write(f"inference_compile_target: {compile_target}\n")
|
||||||
|
if compile_target != "disabled":
|
||||||
|
f.write(f"inference_compile_mode: {compile_mode}\n")
|
||||||
|
if profile_cfg["enabled"]:
|
||||||
|
f.write(f"profile_dir: {profile_dir}\n")
|
||||||
|
if profile_summary_path is not None:
|
||||||
|
f.write(f"profile_summary: {profile_summary_path}\n")
|
||||||
|
|
||||||
|
f.write(f"total_wall_time: {time.time() - eval_wall_start} seconds\n")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
320
jepa.py
320
jepa.py
@@ -2,12 +2,8 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
def detach_clone(v):
|
|
||||||
return v.detach().clone() if torch.is_tensor(v) else v
|
|
||||||
|
|
||||||
class JEPA(nn.Module):
|
class JEPA(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -25,34 +21,124 @@ class JEPA(nn.Module):
|
|||||||
self.action_encoder = action_encoder
|
self.action_encoder = action_encoder
|
||||||
self.projector = projector or nn.Identity()
|
self.projector = projector or nn.Identity()
|
||||||
self.pred_proj = pred_proj or nn.Identity()
|
self.pred_proj = pred_proj or nn.Identity()
|
||||||
|
self._cached_device_tensors = {}
|
||||||
|
self._cached_init_signature = None
|
||||||
|
self._cached_init_emb = None
|
||||||
|
self._cached_goal_signature = None
|
||||||
|
self._cached_goal_emb = None
|
||||||
|
|
||||||
|
def _ensure_runtime_caches(self):
|
||||||
|
if not hasattr(self, "_cached_device_tensors"):
|
||||||
|
self._cached_device_tensors = {}
|
||||||
|
if not hasattr(self, "_cached_init_signature"):
|
||||||
|
self._cached_init_signature = None
|
||||||
|
if not hasattr(self, "_cached_init_emb"):
|
||||||
|
self._cached_init_emb = None
|
||||||
|
if not hasattr(self, "_cached_goal_signature"):
|
||||||
|
self._cached_goal_signature = None
|
||||||
|
if not hasattr(self, "_cached_goal_emb"):
|
||||||
|
self._cached_goal_emb = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _tensor_signature(tensor: torch.Tensor):
|
||||||
|
try:
|
||||||
|
version = tensor._version
|
||||||
|
except RuntimeError:
|
||||||
|
version = None
|
||||||
|
return (
|
||||||
|
str(tensor.device),
|
||||||
|
tensor.dtype,
|
||||||
|
tuple(tensor.shape),
|
||||||
|
tuple(tensor.stride()),
|
||||||
|
tensor.storage_offset(),
|
||||||
|
tensor.data_ptr(),
|
||||||
|
version,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_cached_device_tensor(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
tensor: torch.Tensor,
|
||||||
|
device: torch.device,
|
||||||
|
*,
|
||||||
|
ensure_contiguous: bool = False,
|
||||||
|
):
|
||||||
|
self._ensure_runtime_caches()
|
||||||
|
if tensor.device == device and (not ensure_contiguous or tensor.is_contiguous()):
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
signature = (self._tensor_signature(tensor), str(device), ensure_contiguous)
|
||||||
|
cached = self._cached_device_tensors.get(key)
|
||||||
|
if cached is None or cached[0] != signature:
|
||||||
|
prepared = tensor.to(device, non_blocking=True)
|
||||||
|
if ensure_contiguous and not prepared.is_contiguous():
|
||||||
|
prepared = prepared.contiguous()
|
||||||
|
self._cached_device_tensors[key] = (
|
||||||
|
signature,
|
||||||
|
prepared,
|
||||||
|
)
|
||||||
|
return self._cached_device_tensors[key][1]
|
||||||
|
|
||||||
|
def _ensure_info_device(self, info_dict: dict, device: torch.device):
|
||||||
|
for key, value in list(info_dict.items()):
|
||||||
|
if key.startswith("_lewm_"):
|
||||||
|
continue
|
||||||
|
if torch.is_tensor(value):
|
||||||
|
info_dict[key] = self._get_cached_device_tensor(
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
device,
|
||||||
|
ensure_contiguous=True,
|
||||||
|
)
|
||||||
|
return info_dict
|
||||||
|
|
||||||
|
def _get_cached_init_emb(self, info_dict: dict):
|
||||||
|
self._ensure_runtime_caches()
|
||||||
|
pixels = info_dict["pixels"]
|
||||||
|
signature = self._tensor_signature(pixels)
|
||||||
|
if self._cached_init_signature != signature:
|
||||||
|
init_info = {"pixels": pixels[:, 0]}
|
||||||
|
self._cached_init_emb = self.encode(init_info)["emb"].detach()
|
||||||
|
self._cached_init_signature = signature
|
||||||
|
return self._cached_init_emb
|
||||||
|
|
||||||
|
def _get_cached_goal_emb(self, info_dict: dict):
|
||||||
|
self._ensure_runtime_caches()
|
||||||
|
goal = info_dict["goal"]
|
||||||
|
signature = self._tensor_signature(goal)
|
||||||
|
if self._cached_goal_signature != signature:
|
||||||
|
goal_info = {"pixels": goal[:, 0]}
|
||||||
|
self._cached_goal_emb = self.encode(goal_info)["emb"][:, -1:, :].detach()
|
||||||
|
self._cached_goal_signature = signature
|
||||||
|
return self._cached_goal_emb
|
||||||
|
|
||||||
def encode(self, info):
|
def encode(self, info):
|
||||||
"""Encode observations and actions into embeddings.
|
"""Encode observations and actions into embeddings.
|
||||||
info: dict with pixels and action keys
|
info: dict with pixels and action keys
|
||||||
"""
|
"""
|
||||||
|
with torch.profiler.record_function("lewm.encode"):
|
||||||
|
pixels = info['pixels'].float()
|
||||||
|
b, t = pixels.shape[:2]
|
||||||
|
pixels = pixels.reshape(b * t, *pixels.shape[2:]) # flatten for encoding
|
||||||
|
output = self.encoder(pixels, interpolate_pos_encoding=True)
|
||||||
|
pixels_emb = output.last_hidden_state[:, 0] # cls token
|
||||||
|
emb = self.projector(pixels_emb)
|
||||||
|
info["emb"] = emb.reshape(b, t, -1)
|
||||||
|
|
||||||
pixels = info['pixels'].float()
|
if "action" in info:
|
||||||
b = pixels.size(0)
|
info["act_emb"] = self.action_encoder(info["action"])
|
||||||
pixels = rearrange(pixels, "b t ... -> (b t) ...") # flatten for encoding
|
|
||||||
output = self.encoder(pixels, interpolate_pos_encoding=True)
|
|
||||||
pixels_emb = output.last_hidden_state[:, 0] # cls token
|
|
||||||
emb = self.projector(pixels_emb)
|
|
||||||
info["emb"] = rearrange(emb, "(b t) d -> b t d", b=b)
|
|
||||||
|
|
||||||
if "action" in info:
|
return info
|
||||||
info["act_emb"] = self.action_encoder(info["action"])
|
|
||||||
|
|
||||||
return info
|
|
||||||
|
|
||||||
def predict(self, emb, act_emb):
|
def predict(self, emb, act_emb):
|
||||||
"""Predict next state embedding
|
"""Predict next state embedding
|
||||||
emb: (B, T, D)
|
emb: (B, T, D)
|
||||||
act_emb: (B, T, A_emb)
|
act_emb: (B, T, A_emb)
|
||||||
"""
|
"""
|
||||||
preds = self.predictor(emb, act_emb)
|
with torch.profiler.record_function("lewm.predict"):
|
||||||
preds = self.pred_proj(rearrange(preds, "b t d -> (b t) d"))
|
preds = self.predictor(emb, act_emb)
|
||||||
preds = rearrange(preds, "(b t) d -> b t d", b=emb.size(0))
|
preds = self.pred_proj(preds)
|
||||||
return preds
|
return preds
|
||||||
|
|
||||||
####################
|
####################
|
||||||
## Inference only ##
|
## Inference only ##
|
||||||
@@ -65,89 +151,151 @@ class JEPA(nn.Module):
|
|||||||
- S is the number of action plan samples
|
- S is the number of action plan samples
|
||||||
- T is the time horizon
|
- T is the time horizon
|
||||||
"""
|
"""
|
||||||
|
with torch.profiler.record_function("lewm.rollout"):
|
||||||
|
assert "pixels" in info, "pixels not in info_dict"
|
||||||
|
if history_size < 1:
|
||||||
|
raise ValueError("history_size must be >= 1")
|
||||||
|
|
||||||
assert "pixels" in info, "pixels not in info_dict"
|
H = info["pixels"].size(2)
|
||||||
H = info["pixels"].size(2)
|
B, S, T = action_sequence.shape[:3]
|
||||||
B, S, T = action_sequence.shape[:3]
|
if T < H:
|
||||||
act_0, act_future = torch.split(action_sequence, [H, T - H], dim=2)
|
raise ValueError(
|
||||||
info["action"] = act_0
|
f"action_sequence horizon ({T}) must be >= history length ({H})"
|
||||||
n_steps = T - H
|
)
|
||||||
|
|
||||||
# copy and encode initial info dict
|
# Cache the encoded initial state across solver iterations.
|
||||||
_init = {k: v[:, 0] for k, v in info.items() if torch.is_tensor(v)}
|
init_emb = self._get_cached_init_emb(info)
|
||||||
_init = self.encode(_init)
|
HS = history_size
|
||||||
emb = info["emb"] = _init["emb"].unsqueeze(1).expand(B, S, -1, -1)
|
hist_len = min(HS, init_emb.size(1), H)
|
||||||
_init = {k: detach_clone(v) for k, v in _init.items()}
|
if hist_len < 1:
|
||||||
|
raise ValueError("rollout requires at least one history step")
|
||||||
|
|
||||||
# flatten batch and sample dimensions for rollout
|
init_hist = init_emb[:, -hist_len:]
|
||||||
emb = rearrange(emb, "b s ... -> (b s) ...").clone()
|
init_hist = init_hist.unsqueeze(1).expand(-1, S, -1, -1)
|
||||||
act = rearrange(act_0, "b s ... -> (b s) ...")
|
init_hist = init_hist.reshape(B * S, hist_len, init_hist.size(-1)).contiguous()
|
||||||
act_future = rearrange(act_future, "b s ... -> (b s) ...")
|
|
||||||
|
|
||||||
# rollout predictor autoregressively for n_steps
|
flat_actions = action_sequence.contiguous().view(B * S, T, -1)
|
||||||
HS = history_size
|
action_emb = self.action_encoder(flat_actions)
|
||||||
for t in range(n_steps):
|
act_hist = action_emb[:, H - hist_len : H]
|
||||||
act_emb = self.action_encoder(act)
|
act_future = action_emb[:, H:]
|
||||||
emb_trunc = emb[:, -HS:] # (BS, HS, D)
|
|
||||||
act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb)
|
|
||||||
pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D)
|
|
||||||
emb = torch.cat([emb, pred_emb], dim=1) # (BS, T+1, D)
|
|
||||||
|
|
||||||
next_act = act_future[:, t : t + 1, :] # (BS, 1, action_dim)
|
if HS == 1:
|
||||||
act = torch.cat([act, next_act], dim=1) # (BS, T+1, action_dim)
|
emb_hist = init_hist[:, -1:]
|
||||||
|
act_emb_hist = act_hist[:, -1:]
|
||||||
|
|
||||||
# predict the last state
|
for t in range(act_future.size(1)):
|
||||||
act_emb = self.action_encoder(act) # (BS, T, A_emb)
|
emb_hist = self.predict(emb_hist, act_emb_hist)[:, -1:]
|
||||||
emb_trunc = emb[:, -HS:] # (BS, HS, D)
|
act_emb_hist = act_future[:, t : t + 1]
|
||||||
act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb)
|
|
||||||
pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D)
|
|
||||||
emb = torch.cat([emb, pred_emb], dim=1)
|
|
||||||
|
|
||||||
# unflatten batch and sample dimensions
|
pred_rollout = self.predict(emb_hist, act_emb_hist)[:, -1:]
|
||||||
pred_rollout = rearrange(emb, "(b s) ... -> b s ...", b=B, s=S)
|
else:
|
||||||
info["predicted_emb"] = pred_rollout
|
if torch.is_grad_enabled() and action_sequence.requires_grad:
|
||||||
|
emb_slots = init_hist.split(1, dim=1)
|
||||||
|
act_slots = act_hist.split(1, dim=1)
|
||||||
|
|
||||||
return info
|
for t in range(act_future.size(1)):
|
||||||
|
emb_view = torch.cat(emb_slots[-HS:], dim=1)
|
||||||
|
act_view = torch.cat(act_slots[-HS:], dim=1)
|
||||||
|
pred_emb = self.predict(emb_view, act_view)[:, -1:]
|
||||||
|
next_act_emb = act_future[:, t : t + 1]
|
||||||
|
|
||||||
|
emb_slots = (*emb_slots[-(HS - 1) :], pred_emb)
|
||||||
|
act_slots = (*act_slots[-(HS - 1) :], next_act_emb)
|
||||||
|
|
||||||
|
emb_view = torch.cat(emb_slots[-HS:], dim=1)
|
||||||
|
act_view = torch.cat(act_slots[-HS:], dim=1)
|
||||||
|
pred_rollout = self.predict(emb_view, act_view)[:, -1:]
|
||||||
|
info["predicted_emb"] = pred_rollout.reshape(
|
||||||
|
B, S, *pred_rollout.shape[1:]
|
||||||
|
)
|
||||||
|
return info
|
||||||
|
|
||||||
|
emb_hist = init_hist.new_empty((B * S, HS, init_hist.size(-1)))
|
||||||
|
act_emb_hist = action_emb.new_empty((B * S, HS, action_emb.size(-1)))
|
||||||
|
emb_hist[:, :hist_len].copy_(init_hist)
|
||||||
|
act_emb_hist[:, :hist_len].copy_(act_hist)
|
||||||
|
|
||||||
|
history_order = torch.stack(
|
||||||
|
[
|
||||||
|
(torch.arange(HS, device=action_emb.device) + offset) % HS
|
||||||
|
for offset in range(HS)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
filled = hist_len
|
||||||
|
next_slot = hist_len % HS
|
||||||
|
|
||||||
|
for t in range(act_future.size(1)):
|
||||||
|
if filled < HS:
|
||||||
|
emb_view = emb_hist[:, :filled]
|
||||||
|
act_view = act_emb_hist[:, :filled]
|
||||||
|
elif next_slot == 0:
|
||||||
|
emb_view = emb_hist
|
||||||
|
act_view = act_emb_hist
|
||||||
|
else:
|
||||||
|
order = history_order[next_slot]
|
||||||
|
emb_view = emb_hist.index_select(1, order)
|
||||||
|
act_view = act_emb_hist.index_select(1, order)
|
||||||
|
|
||||||
|
pred_emb = self.predict(emb_view, act_view)[:, -1:]
|
||||||
|
next_act_emb = act_future[:, t : t + 1]
|
||||||
|
emb_hist[:, next_slot : next_slot + 1].copy_(pred_emb)
|
||||||
|
act_emb_hist[:, next_slot : next_slot + 1].copy_(next_act_emb)
|
||||||
|
|
||||||
|
if filled < HS:
|
||||||
|
filled += 1
|
||||||
|
next_slot = (next_slot + 1) % HS
|
||||||
|
|
||||||
|
if filled < HS:
|
||||||
|
emb_view = emb_hist[:, :filled]
|
||||||
|
act_view = act_emb_hist[:, :filled]
|
||||||
|
elif next_slot == 0:
|
||||||
|
emb_view = emb_hist
|
||||||
|
act_view = act_emb_hist
|
||||||
|
else:
|
||||||
|
order = history_order[next_slot]
|
||||||
|
emb_view = emb_hist.index_select(1, order)
|
||||||
|
act_view = act_emb_hist.index_select(1, order)
|
||||||
|
|
||||||
|
pred_rollout = self.predict(emb_view, act_view)[:, -1:]
|
||||||
|
info["predicted_emb"] = pred_rollout.reshape(B, S, *pred_rollout.shape[1:])
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
def criterion(self, info_dict: dict):
    """Compute the cost between predicted embeddings and goal embeddings.

    Reads ``predicted_emb`` and ``goal_emb`` from ``info_dict``. A goal with
    one dimension fewer than the prediction gets a singleton axis inserted at
    dim 1. Only the last step of each is compared.

    Returns:
        Squared error summed over every dimension from 2 onwards, i.e. one
        cost per entry of the first two dimensions.
    """
    with torch.profiler.record_function("lewm.criterion"):
        pred_emb = info_dict["predicted_emb"]  # (B,S, T-1, dim)
        goal_emb = info_dict["goal_emb"]  # (B, S, T, dim)
        if goal_emb.ndim == pred_emb.ndim - 1:
            goal_emb = goal_emb.unsqueeze(1)

        # Broadcast explicitly: handing mse_loss mismatched shapes works, but
        # it emits a "target size ... different to the input size" warning.
        pred_last, goal_last = torch.broadcast_tensors(
            pred_emb[..., -1:, :],
            goal_emb[..., -1:, :].detach(),
        )

        # return last-step cost per action candidate
        cost = F.mse_loss(
            pred_last,
            goal_last,
            reduction="none",
        ).sum(dim=tuple(range(2, pred_emb.ndim)))  # (B, S)

        return cost
|
|
||||||
|
|
||||||
def get_cost(self, info_dict: dict, action_candidates: torch.Tensor):
    """ Compute the cost of action candidates given an info dict with goal and initial state."""
    # Pipeline: validate -> prepare caches/devices -> goal embedding ->
    # rollout -> criterion. The helper methods are defined elsewhere in the
    # class; their exact contracts are not visible here (TODO confirm).
    with torch.profiler.record_function("lewm.get_cost"):
        assert "goal" in info_dict, "goal not in info_dict"

        self._ensure_runtime_caches()

        # Everything is moved to wherever the model parameters live.
        model_device = next(self.parameters()).device
        info_dict = self._ensure_info_device(info_dict, model_device)
        action_candidates = self._get_cached_device_tensor(
            "_lewm_action_candidates",
            action_candidates,
            model_device,
            ensure_contiguous=True,
        )

        info_dict["goal_emb"] = self._get_cached_goal_emb(info_dict)
        rolled_out = self.rollout(info_dict, action_candidates)

        return self.criterion(rolled_out)
|
|
||||||
|
|||||||
@@ -236,9 +236,13 @@ class MLP(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x):
    """
    x: (..., D)

    Inputs with at most two dimensions go straight through ``self.net``.
    Higher-rank inputs are flattened to 2-D for ``self.net`` and the leading
    dimensions are restored on the output.
    """
    leading_shape = x.shape[:-1]
    if len(leading_shape) <= 1:
        return self.net(x)

    flat_out = self.net(x.reshape(-1, x.shape[-1]))
    return flat_out.reshape(*leading_shape, flat_out.shape[-1])
|
||||||
|
|
||||||
|
|
||||||
class ARPredictor(nn.Module):
|
class ARPredictor(nn.Module):
|
||||||
|
|||||||
1624
pusht_results.txt
Normal file
1624
pusht_results.txt
Normal file
File diff suppressed because it is too large
Load Diff
131
scripts/convert_hf_checkpoint.py
Normal file
131
scripts/convert_hf_checkpoint.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
"""Convert LeWM HuggingFace weights into eval-compatible object checkpoints."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import stable_pretraining as spt
|
||||||
|
import torch
|
||||||
|
|
||||||
|
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
if str(REPO_ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(REPO_ROOT))
|
||||||
|
|
||||||
|
from jepa import JEPA
|
||||||
|
from module import ARPredictor, Embedder, MLP
|
||||||
|
|
||||||
|
|
||||||
|
def _load_json(path: Path) -> dict:
    """Decode the UTF-8 JSON file at ``path`` and return the parsed object."""
    return json.loads(path.read_text(encoding="utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_target(config: dict) -> dict:
    """Return a shallow copy of ``config`` without its ``"_target_"`` entry.

    The input mapping is left untouched; a missing ``"_target_"`` key is fine.
    """
    remaining = dict(config)
    remaining.pop("_target_", None)
    return remaining
|
||||||
|
|
||||||
|
|
||||||
|
def infer_config_from_state_dict(state_dict: dict) -> dict:
    """Build a default model config, taking only the action width from the weights.

    Every value is a fixed constant except ``action_encoder.input_dim``, which
    is read from dimension 1 of ``"action_encoder.patch_embed.weight"``.

    Raises:
        KeyError: if that weight is absent from ``state_dict``.
    """
    emb_dim = 192
    head_hidden_dim = 2048
    action_dim = state_dict["action_encoder.patch_embed.weight"].shape[1]

    def _head_cfg() -> dict:
        # A fresh dict per head so the two sections never alias each other.
        return {
            "input_dim": emb_dim,
            "output_dim": emb_dim,
            "hidden_dim": head_hidden_dim,
        }

    encoder_cfg = {
        "size": "tiny",
        "patch_size": 14,
        "image_size": 224,
        "pretrained": False,
        "use_mask_token": False,
    }
    predictor_cfg = {
        "num_frames": 3,
        "input_dim": emb_dim,
        "hidden_dim": emb_dim,
        "output_dim": emb_dim,
        "depth": 6,
        "heads": 16,
        "mlp_dim": 2048,
        "dim_head": 64,
        "dropout": 0.1,
        "emb_dropout": 0.0,
    }
    return {
        "encoder": encoder_cfg,
        "predictor": predictor_cfg,
        "action_encoder": {"input_dim": action_dim, "emb_dim": emb_dim},
        "projector": _head_cfg(),
        "pred_proj": _head_cfg(),
    }
|
||||||
|
|
||||||
|
|
||||||
|
def build_model(config: dict) -> JEPA:
    """Instantiate a JEPA model from a config with one section per sub-module.

    Each section's ``"_target_"`` entry is dropped and the rest is passed as
    keyword arguments. Both MLP heads always use ``torch.nn.BatchNorm1d`` as
    ``norm_fn``, overriding whatever the config says.
    """

    def _kwargs(section: str) -> dict:
        return _strip_target(config[section])

    def _batchnorm_mlp(section: str) -> MLP:
        head_kwargs = _kwargs(section)
        head_kwargs["norm_fn"] = torch.nn.BatchNorm1d
        return MLP(**head_kwargs)

    # Keyword arguments are evaluated top to bottom, so sub-modules are built
    # in the same order as before (encoder, predictor, action encoder, heads).
    return JEPA(
        encoder=spt.backbone.utils.vit_hf(**_kwargs("encoder")),
        predictor=ARPredictor(**_kwargs("predictor")),
        action_encoder=Embedder(**_kwargs("action_encoder")),
        projector=_batchnorm_mlp("projector"),
        pred_proj=_batchnorm_mlp("pred_proj"),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def convert_checkpoint(input_dir: Path, output_name: str) -> tuple[Path, Path]:
    """Convert ``input_dir/weights.pt`` into object and weight checkpoints.

    The model config comes from ``input_dir/config.json`` when that file
    exists; otherwise it is inferred from the state dict.

    Args:
        input_dir: directory containing ``weights.pt`` and optionally
            ``config.json``. Output files are written into the same directory.
        output_name: filename prefix for the two output checkpoints.

    Returns:
        ``(object_path, weight_path)``: the pickled model object and its
        state dict.

    Raises:
        FileNotFoundError: if ``weights.pt`` is missing.
        RuntimeError: if the state dict does not exactly match the model.
    """
    config_path = input_dir / "config.json"
    weights_path = input_dir / "weights.pt"
    if not weights_path.exists():
        raise FileNotFoundError(f"Missing weights file: {weights_path}")

    # The weights file is downloaded content, so restrict unpickling to
    # tensors and plain containers rather than arbitrary Python objects.
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    if config_path.exists():
        config = _load_json(config_path)
    else:
        config = infer_config_from_state_dict(state_dict)

    model = build_model(config)
    # strict=True makes load_state_dict raise RuntimeError on any missing or
    # unexpected key, so no separate check of its return value is needed.
    model.load_state_dict(state_dict, strict=True)
    model.eval()

    object_path = input_dir / f"{output_name}_object.ckpt"
    weight_path = input_dir / f"{output_name}_weight.ckpt"
    # NOTE: the object checkpoint pickles the whole module, so loading it
    # later needs the same class definitions to be importable.
    torch.save(model, object_path)
    torch.save(model.state_dict(), weight_path)
    return object_path, weight_path
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Command-line entry point: convert one checkpoint directory and report outputs."""
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "input_dir",
        type=Path,
        help="Directory containing weights.pt and optionally config.json.",
    )
    cli.add_argument("--output-name", default="lewm")
    options = cli.parse_args()

    # convert_checkpoint returns (object_path, weight_path); print in that order.
    for written_path in convert_checkpoint(options.input_dir, options.output_name):
        print(f"wrote {written_path}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
255
scripts/launch_multinode_eval.sh
Executable file
255
scripts/launch_multinode_eval.sh
Executable file
@@ -0,0 +1,255 @@
|
|||||||
|
#!/usr/bin/env bash
set -euo pipefail

# Launch 2-node LeWM evaluation from node-3.
#
# Defaults match the current cluster layout:
# node-3: 10.16.200.9, node_rank=0
# node-2: 10.16.200.8, node_rank=1
# Each node runs two local torchrun processes for two visible GPUs.
#
# Every setting below keeps a non-empty value from the environment and
# otherwise falls back to the default shown.

# Repository location and rendezvous endpoint.
: "${REPO_ROOT:=/home/lewm/lewm}"
: "${REMOTE_HOST:=lewm@10.16.200.8}"
: "${MASTER_ADDR:=10.16.200.9}"
: "${MASTER_PORT:=29500}"

# torchrun topology and per-node environment.
: "${NNODES:=2}"
: "${NPROC_PER_NODE:=2}"
: "${CUDA_VISIBLE_DEVICES:=0,1}"
: "${STABLEWM_HOME:=/home/lewm/.stable-wm}"

# eval.py options and launcher behaviour switches.
: "${CONFIG_NAME:=pusht.yaml}"
: "${POLICY:=pusht/lewm}"
: "${OUTPUT_FILENAME:=pusht_multinode_results.txt}"
: "${EXTRA_ARGS:=}"
: "${DRY_RUN:=0}"
: "${TAIL_LOGS:=1}"
: "${PRELOAD_WAIT:=0}"
: "${PRELOAD_SIGNAL_FILE:=/tmp/lewm_preload_start}"
: "${PRELOAD_CLEAR_SIGNAL:=1}"

# One log file per node, named after the launch timestamp.
: "${LOG_DIR:=${REPO_ROOT}/logs/multinode}"
mkdir -p "${LOG_DIR}"
RUN_ID="$(date +%Y%m%d_%H%M%S)"
LOCAL_LOG="${LOG_DIR}/${RUN_ID}_node3_rank0.log"
REMOTE_LOG="${LOG_DIR}/${RUN_ID}_node2_rank1.log"

# NOTE(review): StrictHostKeyChecking=no disables host-key verification;
# confirm this is acceptable on the network this runs on.
SSH_OPTS=(
  -F /dev/null
  -o StrictHostKeyChecking=no
  -o ServerAliveInterval=30
  -o ServerAliveCountMax=20
)

# Arguments passed to eval.py on both nodes.
COMMON_ARGS=(
  "--config-name=${CONFIG_NAME}"
  "policy=${POLICY}"
  "multi_node.enabled=true"
  "output.filename=${OUTPUT_FILENAME}"
)

if [[ "${PRELOAD_WAIT}" == "1" ]]; then
  COMMON_ARGS+=(
    "preload_wait.enabled=true"
    "preload_wait.file=${PRELOAD_SIGNAL_FILE}"
  )
fi

if [[ -n "${EXTRA_ARGS}" ]]; then
  # EXTRA_ARGS is deliberately word-split into separate arguments.
  # shellcheck disable=SC2206
  COMMON_ARGS+=(${EXTRA_ARGS})
fi
|
||||||
|
|
||||||
|
# Print the shell command line that starts torchrun for one node.
#   $1 - node rank passed to torchrun as --node_rank
# Reads REPO_ROOT, CUDA_VISIBLE_DEVICES, STABLEWM_HOME, NNODES, NPROC_PER_NODE,
# MASTER_ADDR, MASTER_PORT and COMMON_ARGS. Every value is %q-quoted so the
# result can be handed to `bash -lc` as a single string.
make_command() {
  local node_rank="$1"
  # `arg` is declared local too, so the loop below cannot clobber a caller's
  # variable when the function is called outside a command substitution.
  local repo_q cuda_q stablewm_q arg arg_q eval_args
  printf -v repo_q '%q' "${REPO_ROOT}"
  printf -v cuda_q '%q' "${CUDA_VISIBLE_DEVICES}"
  printf -v stablewm_q '%q' "${STABLEWM_HOME}"

  # Each eval.py argument is quoted separately and prefixed with a space.
  eval_args=""
  for arg in "${COMMON_ARGS[@]}"; do
    printf -v arg_q '%q' "${arg}"
    eval_args+=" ${arg_q}"
  done

  # Pre-quoted pieces use %s; raw values use %q.
  printf 'cd %s && source .venv/bin/activate && export CUDA_VISIBLE_DEVICES=%s && export STABLEWM_HOME=%s && torchrun --nnodes=%q --nproc_per_node=%q --node_rank=%q --master_addr=%q --master_port=%q eval.py%s' \
    "${repo_q}" \
    "${cuda_q}" \
    "${stablewm_q}" \
    "${NNODES}" \
    "${NPROC_PER_NODE}" \
    "${node_rank}" \
    "${MASTER_ADDR}" \
    "${MASTER_PORT}" \
    "${eval_args}"
}
|
||||||
|
|
||||||
|
# Command lines for both nodes. The remote one is quoted once more because it
# is embedded in the string that ssh hands to the remote shell.
LOCAL_CMD="$(make_command 0)"
REMOTE_CMD="$(make_command 1)"
printf -v REMOTE_CMD_Q '%q' "${REMOTE_CMD}"

# Process ids and cleanup command are filled in later; they start empty so the
# cleanup trap can test them safely under `set -u`.
LOCAL_PID=""
REMOTE_PID=""
LOCAL_TAIL_PID=""
REMOTE_TAIL_PID=""
REMOTE_CLEANUP_CMD=""
REMOTE_CLEANUP_CMD_Q=""
|
||||||
|
|
||||||
|
# Follow a log file in the background, prefixing each line with "[label] ".
#   $1 - label shown in the prefix
#   $2 - log file to follow
# Runs under setsid so the tail | sed pipeline can be signalled as a group;
# the caller reads the background pid from $! right after this returns.
start_log_tail() {
  local label="$1" log_file="$2"
  local label_q log_q

  printf -v label_q '%q' "${label}"
  printf -v log_q '%q' "${log_file}"
  setsid bash -lc "tail -n +1 -F ${log_q} 2>/dev/null | sed -u 's/^/[${label_q}] /'" &
}
|
||||||
|
|
||||||
|
# Terminate whichever log-tail processes are still running.
# Tries the whole process group first, then the single pid; never fails.
stop_log_tails() {
  local pid
  for pid in "${LOCAL_TAIL_PID}" "${REMOTE_TAIL_PID}"; do
    # Skip tails that were never started or have already exited.
    [[ -n "${pid}" ]] || continue
    kill -0 "${pid}" 2>/dev/null || continue
    kill -TERM "-${pid}" 2>/dev/null || kill -TERM "${pid}" 2>/dev/null || true
  done
}
|
||||||
|
|
||||||
|
# Print a shell snippet that stops this run's eval processes on the remote:
# pkill -TERM for every pattern, a 2 s grace period, then pkill -KILL.
# The snippet always ends in `true`, so it exits 0 whatever pkill finds.
remote_cleanup_command() {
  local pattern pattern_q
  # pkill -f treats its argument as a regular expression, where a bare '.'
  # matches any character. Escape the dots so the address, "eval.py" and the
  # output filename only match literally.
  # NOTE(review): this should also stop the patterns matching the cleanup
  # shell's own command line, where they appear %q-quoted — confirm on the
  # cluster.
  local addr_re="${MASTER_ADDR//./\\.}"
  local output_re="${OUTPUT_FILENAME//./\\.}"
  local patterns=(
    "torchrun .*--master_addr=${addr_re} .*--master_port=${MASTER_PORT} .*eval\\.py"
    "torchrun .*--master_port=${MASTER_PORT} .*eval\\.py"
    "python.*eval\\.py .*output\\.filename=${output_re}"
  )

  printf 'set +e; '
  for pattern in "${patterns[@]}"; do
    printf -v pattern_q '%q' "${pattern}"
    printf 'pkill -TERM -f %s 2>/dev/null; ' "${pattern_q}"
  done
  printf 'sleep 2; '
  for pattern in "${patterns[@]}"; do
    printf -v pattern_q '%q' "${pattern}"
    printf 'pkill -KILL -f %s 2>/dev/null; ' "${pattern_q}"
  done
  printf 'true'
}
|
||||||
|
|
||||||
|
# Trap handler for INT, TERM and EXIT.
# Does nothing when the status at entry is 0. Otherwise it stops the log
# tails, the local process group, the ssh client and the remote eval
# processes, then exits with the original status.
cleanup() {
  # Must be the first statement so $? is still the pre-trap status.
  local status="$?"
  # Disarm the traps so the handler cannot be re-entered.
  trap - INT TERM EXIT

  if [[ "${status}" -eq 0 ]]; then
    return 0
  fi

  echo
  echo "Stopping multi-node eval..."
  stop_log_tails

  # Local torchrun was started under setsid: signal its group, then the pid.
  if [[ -n "${LOCAL_PID}" ]] && kill -0 "${LOCAL_PID}" 2>/dev/null; then
    kill -TERM "-${LOCAL_PID}" 2>/dev/null || kill -TERM "${LOCAL_PID}" 2>/dev/null || true
  fi

  # REMOTE_PID is the local ssh client; the remote processes are handled by
  # the pkill snippet sent over a fresh ssh connection below.
  if [[ -n "${REMOTE_PID}" ]] && kill -0 "${REMOTE_PID}" 2>/dev/null; then
    kill -TERM "${REMOTE_PID}" 2>/dev/null || true
  fi

  ssh "${SSH_OPTS[@]}" "${REMOTE_HOST}" "bash -lc ${REMOTE_CLEANUP_CMD_Q}" >/dev/null 2>&1 || true

  # Escalate to KILL if the local job outlived the TERM.
  if [[ -n "${LOCAL_PID}" ]] && kill -0 "${LOCAL_PID}" 2>/dev/null; then
    sleep 2
    kill -KILL "-${LOCAL_PID}" 2>/dev/null || kill -KILL "${LOCAL_PID}" 2>/dev/null || true
  fi

  echo "Cleanup requested. Check logs if any process was already exiting:"
  echo " local: ${LOCAL_LOG}"
  echo " remote: ${REMOTE_LOG}"
  exit "${status}"
}
|
||||||
|
|
||||||
|
# From here on, any non-zero exit or INT/TERM runs cleanup.
trap cleanup INT TERM EXIT

# Build the remote kill snippet now so the trap can use it later.
REMOTE_CLEANUP_CMD="$(remote_cleanup_command)"
printf -v REMOTE_CLEANUP_CMD_Q '%q' "${REMOTE_CLEANUP_CMD}"

# Summary of the effective settings.
echo "Launching multi-node eval"
echo " master: ${MASTER_ADDR}:${MASTER_PORT}"
echo " remote: ${REMOTE_HOST}"
echo " repo: ${REPO_ROOT}"
echo " stablewm: ${STABLEWM_HOME}"
echo " config: ${CONFIG_NAME}"
echo " policy: ${POLICY}"
echo " output: ${OUTPUT_FILENAME}"
echo " extra: ${EXTRA_ARGS:-<none>}"
echo " tail logs: ${TAIL_LOGS}"
echo " preload wait: ${PRELOAD_WAIT}"
if [[ "${PRELOAD_WAIT}" == "1" ]]; then
  echo " preload signal: ${PRELOAD_SIGNAL_FILE}"
  echo " start command: touch ${PRELOAD_SIGNAL_FILE}"
fi
echo " local log: ${LOCAL_LOG}"
echo " remote log: ${REMOTE_LOG}"

# Dry run: print both command lines and stop before launching anything.
if [[ "${DRY_RUN}" == "1" ]]; then
  printf -v LOCAL_CMD_Q '%q' "${LOCAL_CMD}"
  echo
  echo "Remote command:"
  echo "ssh ${SSH_OPTS[*]} ${REMOTE_HOST} bash -lc ${REMOTE_CMD_Q}"
  echo
  echo "Local command:"
  echo "bash -lc ${LOCAL_CMD_Q}"
  exit 0
fi

# Remove a signal file left over from an earlier run.
if [[ "${PRELOAD_WAIT}" == "1" && "${PRELOAD_CLEAR_SIGNAL}" == "1" ]]; then
  rm -f "${PRELOAD_SIGNAL_FILE}"
fi

# The remote node starts first; its output is captured in a local log file.
echo "Starting remote node_rank=1..."
ssh "${SSH_OPTS[@]}" "${REMOTE_HOST}" "bash -lc ${REMOTE_CMD_Q}" >"${REMOTE_LOG}" 2>&1 &
REMOTE_PID="$!"

if [[ "${TAIL_LOGS}" == "1" ]]; then
  start_log_tail "node2" "${REMOTE_LOG}"
  REMOTE_TAIL_PID="$!"
fi

sleep 3

echo "Starting local node_rank=0..."
# errexit is off while waiting so both exit codes can be collected.
set +e
setsid bash -lc "${LOCAL_CMD}" >"${LOCAL_LOG}" 2>&1 &
LOCAL_PID="$!"

if [[ "${TAIL_LOGS}" == "1" ]]; then
  start_log_tail "node3" "${LOCAL_LOG}"
  LOCAL_TAIL_PID="$!"
fi

wait "${LOCAL_PID}"
LOCAL_STATUS="$?"

wait "${REMOTE_PID}"
REMOTE_STATUS="$?"
set -e

# Both nodes have finished; the result is reported below without the trap.
stop_log_tails
trap - INT TERM EXIT

echo "Local status: ${LOCAL_STATUS}"
echo "Remote status: ${REMOTE_STATUS}"
echo "Local log: ${LOCAL_LOG}"
echo "Remote log: ${REMOTE_LOG}"

if [[ "${LOCAL_STATUS}" -eq 0 && "${REMOTE_STATUS}" -eq 0 ]]; then
  echo "Multi-node eval complete."
  exit 0
fi

echo "Multi-node eval failed. Tail logs:"
echo "===== local tail ====="
tail -80 "${LOCAL_LOG}" || true
echo "===== remote tail ====="
tail -80 "${REMOTE_LOG}" || true
exit 1
|
||||||
111
scripts/warmup_eval.sh
Executable file
111
scripts/warmup_eval.sh
Executable file
@@ -0,0 +1,111 @@
|
|||||||
|
#!/usr/bin/env bash
set -euo pipefail

# Warm up LeWM evaluation before a formal run.
#
# This script intentionally does a small eval for each task so ROCm/PyTorch can
# initialize GPU contexts, compile predictor graphs, populate kernel caches, and
# touch dataset/checkpoint paths before the timed run.
#
# Site-specific things to check before using this at the competition:
# 1. STABLEWM_HOME points to the directory containing datasets/checkpoints.
# 2. The policy names below match the checkpoint folders at STABLEWM_HOME.
# 3. The dataset names in config/eval/*.yaml match the onsite dataset files.
# 4. The GPU visibility variables match the GPUs allocated to this job.
# 5. WARMUP_NUM_EVAL is close enough to the formal shape to trigger useful
#    compilation, but small enough not to waste much time.

# Run from the repository root (the parent of this script's directory).
REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
cd "${REPO_ROOT}"

: "${PYTHON_BIN:=${REPO_ROOT}/.venv/bin/python}"
: "${STABLEWM_HOME:=/mnt/ASC1637/stablewm}"
export STABLEWM_HOME

# If Slurm allocates multiple GPUs, set these to the allocated physical GPU ids.
# Example for physical GPU 2 and 3:
# ROCR_VISIBLE_DEVICES=2,3 HIP_VISIBLE_DEVICES=0,1 CUDA_VISIBLE_DEVICES=0,1
#
# Important ROCm detail:
# ROCR_VISIBLE_DEVICES uses physical ids.
# HIP_VISIBLE_DEVICES/CUDA_VISIBLE_DEVICES use ids after ROCR remapping.
: "${ROCR_VISIBLE_DEVICES:=0}"
: "${HIP_VISIBLE_DEVICES:=0}"
: "${CUDA_VISIBLE_DEVICES:=0}"
export ROCR_VISIBLE_DEVICES HIP_VISIBLE_DEVICES CUDA_VISIBLE_DEVICES

: "${WARMUP_NUM_EVAL:=10}"
: "${INFERENCE_PRECISION:=fp16}"
: "${OUTPUT_DIR:=/tmp/lewm_warmup}"
mkdir -p "${OUTPUT_DIR}"

# Enable multi-GPU warmup by setting MULTI_GPU=1.
# MULTI_GPU_DEVICES are process-local ids, not physical ids after ROCR remapping.
# Example:
# ROCR_VISIBLE_DEVICES=2,3 HIP_VISIBLE_DEVICES=0,1 MULTI_GPU=1 MULTI_GPU_DEVICES='[0,1]'
: "${MULTI_GPU:=0}"
: "${MULTI_GPU_DEVICES:=[0,1]}"
: "${MULTI_NODE:=0}"

# Multi-node warmup uses the same eval.py entrypoint under torchrun.
# Example:
# torchrun --nnodes=2 --nproc_per_node=2 --node_rank=0 --master_addr=<ip> --master_port=29500 \
# eval.py --config-name=pusht.yaml policy=pusht/lewm multi_node.enabled=true
# This script leaves multi-node launch to the caller.

# Overrides shared by every warmup run.
COMMON_ARGS=(
  "eval.num_eval=${WARMUP_NUM_EVAL}"
  "inference_precision=${INFERENCE_PRECISION}"
)

if [[ "${MULTI_GPU}" == "1" ]]; then
  COMMON_ARGS+=(
    "+multi_gpu.enabled=true"
    "+multi_gpu.devices=${MULTI_GPU_DEVICES}"
  )
fi

if [[ "${MULTI_NODE}" == "1" ]]; then
  COMMON_ARGS+=(
    "multi_node.enabled=true"
  )
fi
|
||||||
|
|
||||||
|
# Run one warmup evaluation.
#   $1 - eval config name (e.g. pusht.yaml)
#   $2 - policy to evaluate (e.g. pusht/lewm)
#   $3 - output filename, written under OUTPUT_DIR
# Under `set -e`, a failing eval.py aborts the whole script.
run_warmup() {
  local config_name="$1" policy="$2" output_name="$3"
  local -a eval_cmd=(
    "${PYTHON_BIN}" eval.py
    "--config-name=${config_name}"
    "policy=${policy}"
    "output.filename=${OUTPUT_DIR}/${output_name}"
    "${COMMON_ARGS[@]}"
  )

  echo
  echo "== Warmup ${config_name} policy=${policy} =="
  "${eval_cmd[@]}"
}
|
||||||
|
|
||||||
|
# Summary of the effective settings.
echo "LeWM warmup"
echo " repo: ${REPO_ROOT}"
echo " python: ${PYTHON_BIN}"
echo " STABLEWM_HOME: ${STABLEWM_HOME}"
echo " ROCR_VISIBLE_DEVICES: ${ROCR_VISIBLE_DEVICES}"
echo " HIP_VISIBLE_DEVICES: ${HIP_VISIBLE_DEVICES}"
echo " CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES}"
echo " WARMUP_NUM_EVAL: ${WARMUP_NUM_EVAL}"
echo " INFERENCE_PRECISION: ${INFERENCE_PRECISION}"
echo " MULTI_GPU: ${MULTI_GPU}"
if [[ "${MULTI_GPU}" == "1" ]]; then
  echo " MULTI_GPU_DEVICES: ${MULTI_GPU_DEVICES}"
fi
echo " MULTI_NODE: ${MULTI_NODE}"

# Defaults match the checkpoint names used in this repo. If onsite checkpoint
# folders differ, override by editing this task list or passing the equivalent
# eval.py command manually.
# Each task uses <task>.yaml, policy <task>/lewm and output warmup_<task>.txt.
for task in pusht reacher cube tworoom; do
  run_warmup "${task}.yaml" "${task}/lewm" "warmup_${task}.txt"
done

echo
echo "Warmup complete. Logs were appended under ${OUTPUT_DIR}."
|
||||||
52
sth.md
Normal file
52
sth.md
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
我建议优先做这 4 类,都是跨数据集成立的:
|
||||||
|
|
||||||
|
1. 压 rollout 内环实现
|
||||||
|
见 jepa.py:127。现在每步都在做 action_encoder、切片、torch.cat、小规
|
||||||
|
模 predict 调用,这种碎片化实现对任何任务都亏。
|
||||||
|
通用改法:
|
||||||
|
|
||||||
|
- 整条 action_sequence 一次性做 action_encoder
|
||||||
|
- emb_hist / act_emb_hist 改成预分配 buffer
|
||||||
|
- 循环里只做索引覆盖或 copy_
|
||||||
|
- 去掉循环内 torch.cat
|
||||||
|
|
||||||
|
2. 减少热路径里的搬运和同步
|
||||||
|
profile 里 aten::copy_ 很重,这不是 TwoRoom 特有问题。重点看
|
||||||
|
jepa.py:67 和 jepa.py:186。
|
||||||
|
通用目标:
|
||||||
|
|
||||||
|
- 模型侧张量尽量全程留在 GPU
|
||||||
|
- 避免热路径反复 .to(device) / 隐式 layout 修复
|
||||||
|
- 到必须和环境交互的边界再一次性转 CPU / numpy
|
||||||
|
- 确保进入 predictor 的张量是 contiguous 的,少触发隐式 copy
|
||||||
|
|
||||||
|
3. 把编译成本移出正式计时
|
||||||
|
现在 torch.compile 默认开在 predictor,见 eval.py:70。102s -> 45s 很
|
||||||
|
像首轮编译预热。
|
||||||
|
通用做法:
|
||||||
|
|
||||||
|
- 在正式 start_time 前做一次 dummy predict 或 dummy rollout
|
||||||
|
- 保留只编译 predictor/predict,不要编译整个 solver
|
||||||
|
|
||||||
|
4. 减少临时对象和 shape bookkeeping
|
||||||
|
这是所有任务都会受益的。
|
||||||
|
重点看:
|
||||||
|
|
||||||
|
- jepa.py:100 到 jepa.py:106
|
||||||
|
- jepa.py:143 到 jepa.py:148
|
||||||
|
方向是:
|
||||||
|
- 能循环外做的 reshape,不放循环里
|
||||||
|
- 能原地更新,不新建张量
|
||||||
|
- 少做 dict 字段增删和中间容器组装
|
||||||
|
|
||||||
|
不建议优先做的通用性较差方案:
|
||||||
|
|
||||||
|
- 调 TwoRoom 专属 cache 规则
|
||||||
|
- 改数据集采样逻辑
|
||||||
|
- 按小数据集特点缩短 horizon
|
||||||
|
- 直接改 CEM 超参当“优化”
|
||||||
|
|
||||||
|
如果你要我直接开始改,我建议第一批只做两件事:
|
||||||
|
|
||||||
|
- 重写 jepa.py:127 这段 rollout,去掉循环内 action_encoder + cat
|
||||||
|
- 在 eval.py:306 前加 compile warmup
|
||||||
Reference in New Issue
Block a user