diff --git a/.gitignore b/.gitignore index 716a108..0c5b18d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,3 @@ -.venv/ outputs/ __pycache__/ *.py[cod] diff --git a/.venv/.gitignore b/.venv/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/gd.py b/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/gd.py new file mode 100644 index 0000000..a183b7b --- /dev/null +++ b/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/gd.py @@ -0,0 +1,252 @@ +"""Gradient-based solver for model-based planning.""" + +import time +from typing import Any + +import gymnasium as gym +import numpy as np +import torch +from gymnasium.spaces import Box +from loguru import logger as logging + +from .solver import Costable + + +class GradientSolver(torch.nn.Module): + """Gradient-based solver using backpropagation through the world model. + + Args: + model: World model implementing the Costable protocol. + n_steps: Number of gradient descent iterations. + batch_size: Number of environments to process in parallel. + var_scale: Initial variance scale for action perturbations. + num_samples: Number of action samples to optimize in parallel. + action_noise: Noise added to actions during optimization. + device: Device for tensor computations. + seed: Random seed for reproducibility. + optimizer_cls: PyTorch optimizer class to use. + optimizer_kwargs: Keyword arguments for the optimizer. + """ + + def __init__( + self, + model: Costable, + n_steps: int, + batch_size: int | None = None, + var_scale: float = 1, + num_samples: int = 1, + action_noise: float = 0.0, + device: str | torch.device = 'cpu', + seed: int = 1234, + optimizer_cls: type[torch.optim.Optimizer] = torch.optim.SGD, + optimizer_kwargs: dict | None = None, + ) -> None: + super().__init__() + self.model = model + self.n_steps = n_steps + self.batch_size = batch_size + self.num_samples = num_samples + self.var_scale = var_scale + self.action_noise = action_noise + self.device = device + self.torch_gen = torch.Generator(device=device).manual_seed(seed) + + self.optimizer_cls = optimizer_cls + self.optimizer_kwargs = ( + optimizer_kwargs if optimizer_kwargs is not None else {'lr': 1.0} + ) + + self._configured = False + self._n_envs = None + self._action_dim = None + self._config = None + + def configure( + self, *, action_space: gym.Space, n_envs: int, config: Any + ) -> None: + """Configure the solver with environment specifications.""" + self._action_space = action_space + self._n_envs = n_envs + self._config = config + self._action_dim = int(np.prod(action_space.shape[1:])) + self._configured = True + + if not isinstance(action_space, Box): + logging.warning( + f'Action space is discrete, got {type(action_space)}. GradientSolver may not work as expected.' + ) + + @property + def n_envs(self) -> int: + """Number of parallel environments.""" + return self._n_envs + + @property + def action_dim(self) -> int: + """Flattened action dimension including action_block grouping.""" + return self._action_dim * self._config.action_block + + @property + def horizon(self) -> int: + """Planning horizon in timesteps.""" + return self._config.horizon + + def __call__(self, *args: Any, **kwargs: Any) -> dict: + """Make solver callable, forwarding to solve().""" + return self.solve(*args, **kwargs) + + def init_action(self, actions: torch.Tensor | None = None) -> None: + """Initialize the action tensor for optimization.""" + device = torch.device(self.device) + if actions is None: + actions = torch.zeros( + (self._n_envs, 0, self.action_dim), device=device + ) + elif actions.device != device: + actions = actions.to(device, non_blocking=True) + + # fill remaining action + remaining = self.horizon - actions.shape[1] + + if remaining > 0: + new_actions = torch.zeros( + self._n_envs, + remaining, + self.action_dim, + device=actions.device, + dtype=actions.dtype, + ) + actions = torch.cat([actions, new_actions], dim=1) + + actions = actions.unsqueeze(1).repeat_interleave( + self.num_samples, dim=1 + ) # add sample dim + actions[:, 1:] += ( + torch.randn( + actions[:, 1:].shape, + generator=self.torch_gen, + device=self.device, + ) + * self.var_scale + ) # add small noise to all samples except the first one + + # reset actions + if hasattr(self, 'init'): + self.init.copy_(actions) + else: + self.register_parameter('init', torch.nn.Parameter(actions)) + + def solve( + self, info_dict: dict, init_action: torch.Tensor | None = None + ) -> dict: + """Solve the planning problem using gradient descent.""" + start_time = time.time() + outputs = { + 'cost': [], # Will store list of cost histories per batch + 'actions': None, + } + + with torch.no_grad(): + self.init_action(init_action) + + # Determine batch size (default to all envs if not specified which can cause memory issues) + batch_size = ( + self.batch_size if self.batch_size is not None else self.n_envs + ) + total_envs = self.n_envs + + # Lists to hold results from each batch to be concatenated later + batch_top_actions_list = [] + + # --- Outer Loop: Iterate over batches --- + for start_idx in range(0, total_envs, batch_size): + end_idx = min(start_idx + batch_size, total_envs) + current_bs = end_idx - start_idx + + batch_init = self.init[start_idx:end_idx].clone().detach() + batch_init.requires_grad = True + + # We initialize the optimizer class passed in __init__ with the kwargs + optim = self.optimizer_cls([batch_init], **self.optimizer_kwargs) + + # Prepare Batch Infos + # Slice the input info_dict and then expand dimensions + expanded_infos = {} + for k, v in info_dict.items(): + # Slice the data for the current batch indices + # Assumes input data dim 0 corresponds to n_envs + if torch.is_tensor(v): + batch_v = v[start_idx:end_idx] + batch_v = batch_v.unsqueeze(1) + batch_v = batch_v.expand( + current_bs, self.num_samples, *batch_v.shape[2:] + ) + elif isinstance(v, np.ndarray): + batch_v = v[start_idx:end_idx] + batch_v = np.repeat( + batch_v[:, None, ...], self.num_samples, axis=1 + ) + expanded_infos[k] = batch_v + + final_batch_cost = None + + for step in range(self.n_steps): + current_info = expanded_infos.copy() + + # Calculate cost using the batch parameter + costs = self.model.get_cost(current_info, batch_init) + + assert isinstance(costs, torch.Tensor), ( + f'Got {type(costs)} cost, expect torch.Tensor' + ) + assert ( + costs.ndim == 2 + and costs.shape[0] == current_bs + and costs.shape[1] == self.num_samples + ), ( + f'Cost should be of shape ({current_bs}, {self.num_samples}), got {costs.shape}' + ) + assert costs.requires_grad, ( + 'Cost must requires_grad for GD solver.' + ) + + cost = costs.sum() # Sum cost for this batch + cost.backward() + optim.step() + optim.zero_grad(set_to_none=True) + + # Add noise + if self.action_noise > 0: + batch_init.data += ( + torch.randn( + batch_init.shape, + generator=self.torch_gen, + device=self.device, + ) + * self.action_noise + ) + + final_batch_cost = costs.detach().min(dim=1).values + + # Store cost history for this batch + outputs['cost'].append(final_batch_cost) + + # Update the global self.init with the optimized batch values + with torch.no_grad(): + self.init[start_idx:end_idx] = batch_init + + top_idx = costs.argmin(dim=1) + batch_indices = torch.arange(current_bs, device=self.device) + + top_actions_batch = batch_init[batch_indices, top_idx] + batch_top_actions_list.append(top_actions_batch.detach()) + + # Concatenate all batch results + outputs['actions'] = torch.cat(batch_top_actions_list, dim=0) + outputs['cost'] = torch.cat(outputs['cost']).cpu().tolist() + end_time = time.time() + print( + f'GradientSolver.solve completed in {end_time - start_time:.4f} seconds.' + ) + + return outputs diff --git a/config/eval/pusht.yaml b/config/eval/pusht.yaml index e92ddeb..70f1907 100644 --- a/config/eval/pusht.yaml +++ b/config/eval/pusht.yaml @@ -1,6 +1,6 @@ defaults: - launcher: local - - solver: cem + - solver: gradient - _self_ world: diff --git a/config/eval/solver/gradient.yaml b/config/eval/solver/gradient.yaml new file mode 100644 index 0000000..c3e4094 --- /dev/null +++ b/config/eval/solver/gradient.yaml @@ -0,0 +1,14 @@ +_target_: stable_worldmodel.solver.GradientSolver +model: ??? +# Original adam.yaml reference: n_steps=30, num_samples=100, batch_size=1, lr=0.1. +n_steps: 125 +batch_size: 50 +num_samples: 1 +action_noise: 0 +device: "cuda" +seed: ${seed} +optimizer_cls: + _target_: hydra.utils.get_class + path: torch.optim.AdamW +optimizer_kwargs: + lr: 0.05 diff --git a/eval.py b/eval.py index 732d83a..7c487bb 100644 --- a/eval.py +++ b/eval.py @@ -153,6 +153,12 @@ def get_inference_context(cfg, device): ) +def get_eval_grad_context(solver=None): + if isinstance(solver, swm.solver.GradientSolver): + return torch.enable_grad() + return torch.inference_mode() + + def make_profiler(cfg, results_path): profile_cfg = get_profile_cfg(cfg) if not profile_cfg["enabled"]: @@ -345,6 +351,7 @@ def run_eval_subset( ) else: policy = swm.policy.RandomPolicy() + solver = None inference_ctx = nullcontext() inference_precision = "fp32" compile_cfg = get_compile_cfg(local_cfg) @@ -357,7 +364,7 @@ def run_eval_subset( if str(device).startswith("cuda") and torch.cuda.is_available(): torch.cuda.synchronize() start_time = time.time() - with torch.inference_mode(): + with get_eval_grad_context(solver): with profiler_ctx as profiler: with inference_ctx: with torch.profiler.record_function("eval.world_evaluate_from_dataset"): diff --git a/jepa.py b/jepa.py index 368856e..dc7238e 100644 --- a/jepa.py +++ b/jepa.py @@ -189,6 +189,27 @@ class JEPA(nn.Module): pred_rollout = self.predict(emb_hist, act_emb_hist)[:, -1:] else: + if torch.is_grad_enabled() and action_sequence.requires_grad: + emb_slots = init_hist.split(1, dim=1) + act_slots = act_hist.split(1, dim=1) + + for t in range(act_future.size(1)): + emb_view = torch.cat(emb_slots[-HS:], dim=1) + act_view = torch.cat(act_slots[-HS:], dim=1) + pred_emb = self.predict(emb_view, act_view)[:, -1:] + next_act_emb = act_future[:, t : t + 1] + + emb_slots = (*emb_slots[-(HS - 1) :], pred_emb) + act_slots = (*act_slots[-(HS - 1) :], next_act_emb) + + emb_view = torch.cat(emb_slots[-HS:], dim=1) + act_view = torch.cat(act_slots[-HS:], dim=1) + pred_rollout = self.predict(emb_view, act_view)[:, -1:] + info["predicted_emb"] = pred_rollout.reshape( + B, S, *pred_rollout.shape[1:] + ) + return info + emb_hist = init_hist.new_empty((B * S, HS, init_hist.size(-1))) act_emb_hist = action_emb.new_empty((B * S, HS, action_emb.size(-1))) emb_hist[:, :hist_len].copy_(init_hist) diff --git a/pusht_results.txt b/pusht_results.txt index 93375e4..bc865a6 100644 --- a/pusht_results.txt +++ b/pusht_results.txt @@ -1228,3 +1228,265 @@ evaluation_time: 16.38284730911255 seconds inference_precision: fp16 inference_compile_target: predictor inference_compile_mode: reduce-overhead + +==== CONFIG ==== +cache_dir: null +solver: + _target_: stable_worldmodel.solver.CEMSolver + model: ??? + batch_size: 8 + num_samples: 64 + var_scale: 1.0 + n_steps: 10 + topk: 8 + device: cuda + seed: ${seed} +world: + env_name: swm/PushT-v1 + num_envs: ${eval.num_eval} + max_episode_steps: ??? + history_size: 1 + frame_skip: 1 +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio + - state +seed: 42 +policy: pusht/lewm +inference_precision: fp16 +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + save_video: false + dataset_name: pusht_expert_train + callables: + - method: _set_state + args: + state: + value: state + - method: _set_goal_state + args: + goal_state: + value: goal_state +output: + filename: pusht_results.txt + +==== RESULTS ==== +metrics: {'success_rate': 90.0, 'episode_successes': array([ True, True, True, True, True, True, True, True, True, + True, True, True, True, True, False, False, True, True, + True, True, True, True, True, True, True, True, True, + True, True, True, True, True, True, True, True, False, + True, True, True, True, True, True, False, True, True, + True, True, True, False, True]), 'seeds': None} +evaluation_time: 16.081845998764038 seconds +inference_precision: fp16 +inference_compile_target: predictor +inference_compile_mode: reduce-overhead + +==== CONFIG ==== +cache_dir: null +solver: + _target_: stable_worldmodel.solver.GradientSolver + model: ??? + n_steps: 10 + batch_size: 8 + num_samples: 1 + action_noise: 0 + device: cuda + seed: ${seed} + optimizer_cls: + _target_: hydra.utils.get_class + path: torch.optim.AdamW + optimizer_kwargs: + lr: 0.05 +world: + env_name: swm/PushT-v1 + num_envs: ${eval.num_eval} + max_episode_steps: ??? + history_size: 1 + frame_skip: 1 +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio + - state +seed: 42 +policy: pusht/lewm +inference_precision: fp16 +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + save_video: false + dataset_name: pusht_expert_train + callables: + - method: _set_state + args: + state: + value: state + - method: _set_goal_state + args: + goal_state: + value: goal_state +output: + filename: pusht_results.txt + +==== RESULTS ==== +metrics: {'success_rate': 46.0, 'episode_successes': array([False, False, True, False, True, True, True, False, False, + True, False, False, True, False, False, False, False, False, + True, True, False, True, False, True, True, False, True, + False, True, True, True, False, False, True, False, False, + True, True, True, False, False, False, False, True, True, + True, True, False, False, False]), 'seeds': None} +evaluation_time: 63.84614443778992 seconds +inference_precision: fp16 +inference_compile_target: predictor +inference_compile_mode: reduce-overhead + +==== CONFIG ==== +cache_dir: null +solver: + _target_: stable_worldmodel.solver.GradientSolver + model: ??? + n_steps: 125 + batch_size: 50 + num_samples: 1 + action_noise: 0 + device: cuda + seed: ${seed} + optimizer_cls: + _target_: hydra.utils.get_class + path: torch.optim.AdamW + optimizer_kwargs: + lr: 0.05 +world: + env_name: swm/PushT-v1 + num_envs: ${eval.num_eval} + max_episode_steps: ??? + history_size: 1 + frame_skip: 1 +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio + - state +seed: 42 +policy: pusht/lewm +inference_precision: fp16 +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + save_video: false + dataset_name: pusht_expert_train + callables: + - method: _set_state + args: + state: + value: state + - method: _set_goal_state + args: + goal_state: + value: goal_state +output: + filename: pusht_results.txt +profile: + enabled: false + +==== RESULTS ==== +metrics: {'success_rate': 90.0, 'episode_successes': array([ True, False, True, True, True, True, True, True, True, + True, False, True, True, True, True, True, False, True, + True, True, True, True, True, True, True, True, True, + True, True, True, True, True, True, True, True, True, + True, True, True, False, True, True, False, True, True, + True, True, True, True, True]), 'seeds': None} +evaluation_time: 15.638921022415161 seconds +inference_precision: fp16 +inference_compile_target: predictor +inference_compile_mode: reduce-overhead + +==== CONFIG ==== +cache_dir: null +solver: + _target_: stable_worldmodel.solver.GradientSolver + model: ??? + n_steps: 125 + batch_size: 50 + num_samples: 1 + action_noise: 0 + device: cuda + seed: ${seed} + optimizer_cls: + _target_: hydra.utils.get_class + path: torch.optim.AdamW + optimizer_kwargs: + lr: 0.05 +world: + env_name: swm/PushT-v1 + num_envs: ${eval.num_eval} + max_episode_steps: ??? + history_size: 1 + frame_skip: 1 +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio + - state +seed: 42 +policy: pusht/lewm +inference_precision: fp16 +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + save_video: false + dataset_name: pusht_expert_train + callables: + - method: _set_state + args: + state: + value: state + - method: _set_goal_state + args: + goal_state: + value: goal_state +output: + filename: pusht_results.txt + +==== RESULTS ==== +metrics: {'success_rate': 90.0, 'episode_successes': array([ True, False, True, True, True, True, True, True, True, + True, False, True, True, True, True, True, False, True, + True, True, True, True, True, True, True, True, True, + True, True, True, True, True, True, True, True, True, + True, True, True, False, True, True, False, True, True, + True, True, True, True, True]), 'seeds': None} +evaluation_time: 16.060093879699707 seconds +inference_precision: fp16 +inference_compile_target: predictor +inference_compile_mode: reduce-overhead