更改求解器 step=150时 成功率更高 step=125时 速度更快 成功率持平

This commit is contained in:
qihuanye
2026-05-04 07:55:13 +00:00
parent 4c3fdbcce6
commit cf43af0729
8 changed files with 558 additions and 3 deletions

1
.gitignore vendored
View File

@@ -1,4 +1,3 @@
.venv/
outputs/
__pycache__/
*.py[cod]

0
.venv/.gitignore vendored Normal file
View File

View File

@@ -0,0 +1,252 @@
"""Gradient-based solver for model-based planning."""
import time
from typing import Any
import gymnasium as gym
import numpy as np
import torch
from gymnasium.spaces import Box
from loguru import logger as logging
from .solver import Costable
class GradientSolver(torch.nn.Module):
    """Gradient-based solver using backpropagation through the world model.

    Candidate action sequences are stored in the ``init`` parameter with
    shape ``(n_envs, num_samples, horizon, action_dim)`` and refined by
    running an optimizer on the cost returned by ``model.get_cost``.

    Args:
        model: World model implementing the Costable protocol.
        n_steps: Number of gradient descent iterations.
        batch_size: Number of environments to process in parallel.
            ``None`` processes all environments in a single batch.
        var_scale: Initial variance scale for action perturbations.
        num_samples: Number of action samples to optimize in parallel.
        action_noise: Noise added to actions during optimization.
        device: Device for tensor computations.
        seed: Random seed for reproducibility.
        optimizer_cls: PyTorch optimizer class to use.
        optimizer_kwargs: Keyword arguments for the optimizer. Defaults to
            ``{'lr': 1.0}`` when not given.
    """

    def __init__(
        self,
        model: Costable,
        n_steps: int,
        batch_size: int | None = None,
        var_scale: float = 1,
        num_samples: int = 1,
        action_noise: float = 0.0,
        device: str | torch.device = 'cpu',
        seed: int = 1234,
        optimizer_cls: type[torch.optim.Optimizer] = torch.optim.SGD,
        optimizer_kwargs: dict | None = None,
    ) -> None:
        super().__init__()
        self.model = model
        self.n_steps = n_steps
        self.batch_size = batch_size
        self.num_samples = num_samples
        self.var_scale = var_scale
        self.action_noise = action_noise
        self.device = device
        # Dedicated generator so solver noise is reproducible from `seed`.
        self.torch_gen = torch.Generator(device=device).manual_seed(seed)
        self.optimizer_cls = optimizer_cls
        self.optimizer_kwargs = (
            optimizer_kwargs if optimizer_kwargs is not None else {'lr': 1.0}
        )
        # Populated by configure().
        self._configured = False
        self._n_envs = None
        self._action_dim = None
        self._config = None

    def configure(
        self, *, action_space: gym.Space, n_envs: int, config: Any
    ) -> None:
        """Configure the solver with environment specifications.

        Args:
            action_space: Action space of the environments; the leading
                dimension of its shape is skipped when computing the
                per-step action size.
            n_envs: Number of parallel environments.
            config: Planning config; ``horizon`` and ``action_block`` are
                read from it by the properties below.
        """
        # Warn before touching `shape` so the hint is emitted even when the
        # space cannot be handled at all.
        if not isinstance(action_space, Box):
            logging.warning(
                f'Action space is discrete, got {type(action_space)}. GradientSolver may not work as expected.'
            )
        self._action_space = action_space
        self._n_envs = n_envs
        self._config = config
        self._action_dim = int(np.prod(action_space.shape[1:]))
        self._configured = True

    @property
    def n_envs(self) -> int:
        """Number of parallel environments."""
        return self._n_envs

    @property
    def action_dim(self) -> int:
        """Flattened action dimension including action_block grouping."""
        return self._action_dim * self._config.action_block

    @property
    def horizon(self) -> int:
        """Planning horizon in timesteps."""
        return self._config.horizon

    def __call__(self, *args: Any, **kwargs: Any) -> dict:
        """Make solver callable, forwarding to solve()."""
        return self.solve(*args, **kwargs)

    def init_action(self, actions: torch.Tensor | None = None) -> None:
        """Initialize the action tensor for optimization.

        Args:
            actions: Optional warm-start actions of shape
                ``(n_envs, t, action_dim)``. Missing timesteps up to the
                horizon are filled with zeros.

        The result is stored in the ``init`` parameter with an extra sample
        dimension. Every sample except the first is perturbed with Gaussian
        noise scaled by ``var_scale``.
        """
        device = torch.device(self.device)
        # All work here is plain initialization; running it under no_grad
        # keeps the in-place update of the `init` leaf parameter legal even
        # when this method is called outside solve().
        with torch.no_grad():
            if actions is None:
                actions = torch.zeros(
                    (self._n_envs, 0, self.action_dim), device=device
                )
            elif actions.device != device:
                actions = actions.to(device, non_blocking=True)

            # Fill the remaining timesteps with zeros.
            remaining = self.horizon - actions.shape[1]
            if remaining > 0:
                padding = torch.zeros(
                    self._n_envs,
                    remaining,
                    self.action_dim,
                    device=actions.device,
                    dtype=actions.dtype,
                )
                actions = torch.cat([actions, padding], dim=1)

            # Add the sample dimension.
            actions = actions.unsqueeze(1).repeat_interleave(
                self.num_samples, dim=1
            )
            # Perturb every sample except the first one.
            actions[:, 1:] += (
                torch.randn(
                    actions[:, 1:].shape,
                    generator=self.torch_gen,
                    device=self.device,
                )
                * self.var_scale
            )

            existing = getattr(self, 'init', None)
            if existing is not None and existing.shape == actions.shape:
                existing.copy_(actions)
            else:
                # First call, or the problem size changed since the last
                # call: (re)create the parameter instead of failing in copy_.
                self.register_parameter('init', torch.nn.Parameter(actions))

    def _expand_infos(
        self, info_dict: dict, start_idx: int, end_idx: int
    ) -> dict:
        """Slice ``info_dict`` to one env batch and add the sample dimension.

        Tensors and arrays are sliced on dim 0 (assumed to index the
        environments) and repeated ``num_samples`` times on a new dim 1.
        Any other value is passed through unchanged.
        """
        current_bs = end_idx - start_idx
        expanded = {}
        for key, value in info_dict.items():
            if torch.is_tensor(value):
                sliced = value[start_idx:end_idx].unsqueeze(1)
                expanded[key] = sliced.expand(
                    current_bs, self.num_samples, *sliced.shape[2:]
                )
            elif isinstance(value, np.ndarray):
                sliced = value[start_idx:end_idx]
                expanded[key] = np.repeat(
                    sliced[:, None, ...], self.num_samples, axis=1
                )
            else:
                # Non-array entries cannot be sliced per env; forward as-is
                # rather than reusing a stale value from a previous key.
                expanded[key] = value
        return expanded

    def solve(
        self, info_dict: dict, init_action: torch.Tensor | None = None
    ) -> dict:
        """Solve the planning problem using gradient descent.

        Args:
            info_dict: Inputs forwarded to ``model.get_cost``. Tensor and
                ndarray values are expected to be indexed by environment on
                dim 0.
            init_action: Optional warm-start actions, see ``init_action``.

        Returns:
            Dict with ``'actions'``, the lowest-cost action sequence per
            environment of shape ``(n_envs, horizon, action_dim)``, and
            ``'cost'``, a list with the matching cost per environment. The
            cost is the one evaluated at the start of the last iteration,
            i.e. before the final optimizer step.

        Raises:
            ValueError: If ``n_steps`` or the effective batch size is
                smaller than 1.
        """
        if self.n_steps < 1:
            raise ValueError(
                f'n_steps must be at least 1, got {self.n_steps}'
            )
        # Default to all envs in one batch, which can be memory hungry.
        batch_size = (
            self.batch_size if self.batch_size is not None else self.n_envs
        )
        if batch_size < 1:
            raise ValueError(f'batch_size must be at least 1, got {batch_size}')

        start_time = time.perf_counter()
        self.init_action(init_action)
        total_envs = self.n_envs

        batch_costs = []
        batch_top_actions = []
        for start_idx in range(0, total_envs, batch_size):
            end_idx = min(start_idx + batch_size, total_envs)
            current_bs = end_idx - start_idx

            # Optimize a detached copy so each batch has its own leaf.
            batch_init = self.init[start_idx:end_idx].clone().detach()
            batch_init.requires_grad = True
            optim = self.optimizer_cls([batch_init], **self.optimizer_kwargs)
            expanded_infos = self._expand_infos(info_dict, start_idx, end_idx)

            costs = None
            for _ in range(self.n_steps):
                # enable_grad lets the solver work when the caller runs
                # under torch.no_grad().
                with torch.enable_grad():
                    # Shallow copy so get_cost cannot add keys that leak
                    # into the next iteration.
                    costs = self.model.get_cost(
                        expanded_infos.copy(), batch_init
                    )
                    assert isinstance(costs, torch.Tensor), (
                        f'Got {type(costs)} cost, expect torch.Tensor'
                    )
                    assert (
                        costs.ndim == 2
                        and costs.shape[0] == current_bs
                        and costs.shape[1] == self.num_samples
                    ), (
                        f'Cost should be of shape ({current_bs}, {self.num_samples}), got {costs.shape}'
                    )
                    assert costs.requires_grad, (
                        'Cost must requires_grad for GD solver.'
                    )
                    costs.sum().backward()
                optim.step()
                optim.zero_grad(set_to_none=True)

                if self.action_noise > 0:
                    with torch.no_grad():
                        batch_init.add_(
                            torch.randn(
                                batch_init.shape,
                                generator=self.torch_gen,
                                device=self.device,
                            )
                            * self.action_noise
                        )

            final_costs = costs.detach()
            batch_costs.append(final_costs.min(dim=1).values)

            # Write the optimized values back into the shared parameter.
            with torch.no_grad():
                self.init[start_idx:end_idx] = batch_init

            top_idx = final_costs.argmin(dim=1)
            batch_indices = torch.arange(current_bs, device=self.device)
            batch_top_actions.append(
                batch_init[batch_indices, top_idx].detach()
            )

        outputs = {
            'cost': torch.cat(batch_costs).cpu().tolist(),
            'actions': torch.cat(batch_top_actions, dim=0),
        }
        elapsed = time.perf_counter() - start_time
        logging.info(
            f'GradientSolver.solve completed in {elapsed:.4f} seconds.'
        )
        return outputs

View File

@@ -1,6 +1,6 @@
defaults:
- launcher: local
- solver: cem
- solver: gradient
- _self_
world:

View File

@@ -0,0 +1,14 @@
# Hydra config that instantiates the gradient-based planning solver.
_target_: stable_worldmodel.solver.GradientSolver
# Mandatory value (???); must be supplied by the caller.
model: ???
# Original adam.yaml reference: n_steps=30, num_samples=100, batch_size=1, lr=0.1.
# Number of optimizer iterations per solve. The commit note reports 125 as the
# faster setting and 150 as giving a higher success rate.
n_steps: 125
# Number of environments optimized together in one batch.
batch_size: 50
# Number of candidate action sequences per environment.
num_samples: 1
# 0 disables the per-step action noise.
action_noise: 0
device: "cuda"
# Interpolated from the top-level `seed` key.
seed: ${seed}
# Optimizer class, resolved from its dotted path via hydra.utils.get_class.
optimizer_cls:
_target_: hydra.utils.get_class
path: torch.optim.AdamW
# Keyword arguments passed to the optimizer constructor.
optimizer_kwargs:
lr: 0.05

View File

@@ -153,6 +153,12 @@ def get_inference_context(cfg, device):
)
def get_eval_grad_context(solver=None):
    """Pick the autograd context manager used to wrap evaluation.

    Args:
        solver: The planning solver in use, or ``None``.

    Returns:
        ``torch.enable_grad()`` when ``solver`` is a ``GradientSolver``,
        otherwise ``torch.inference_mode()``.
    """
    needs_grad = isinstance(solver, swm.solver.GradientSolver)
    return torch.enable_grad() if needs_grad else torch.inference_mode()
def make_profiler(cfg, results_path):
profile_cfg = get_profile_cfg(cfg)
if not profile_cfg["enabled"]:
@@ -345,6 +351,7 @@ def run_eval_subset(
)
else:
policy = swm.policy.RandomPolicy()
solver = None
inference_ctx = nullcontext()
inference_precision = "fp32"
compile_cfg = get_compile_cfg(local_cfg)
@@ -357,7 +364,7 @@ def run_eval_subset(
if str(device).startswith("cuda") and torch.cuda.is_available():
torch.cuda.synchronize()
start_time = time.time()
with torch.inference_mode():
with get_eval_grad_context(solver):
with profiler_ctx as profiler:
with inference_ctx:
with torch.profiler.record_function("eval.world_evaluate_from_dataset"):

21
jepa.py
View File

@@ -189,6 +189,27 @@ class JEPA(nn.Module):
pred_rollout = self.predict(emb_hist, act_emb_hist)[:, -1:]
else:
if torch.is_grad_enabled() and action_sequence.requires_grad:
emb_slots = init_hist.split(1, dim=1)
act_slots = act_hist.split(1, dim=1)
for t in range(act_future.size(1)):
emb_view = torch.cat(emb_slots[-HS:], dim=1)
act_view = torch.cat(act_slots[-HS:], dim=1)
pred_emb = self.predict(emb_view, act_view)[:, -1:]
next_act_emb = act_future[:, t : t + 1]
emb_slots = (*emb_slots[-(HS - 1) :], pred_emb)
act_slots = (*act_slots[-(HS - 1) :], next_act_emb)
emb_view = torch.cat(emb_slots[-HS:], dim=1)
act_view = torch.cat(act_slots[-HS:], dim=1)
pred_rollout = self.predict(emb_view, act_view)[:, -1:]
info["predicted_emb"] = pred_rollout.reshape(
B, S, *pred_rollout.shape[1:]
)
return info
emb_hist = init_hist.new_empty((B * S, HS, init_hist.size(-1)))
act_emb_hist = action_emb.new_empty((B * S, HS, action_emb.size(-1)))
emb_hist[:, :hist_len].copy_(init_hist)

View File

@@ -1228,3 +1228,265 @@ evaluation_time: 16.38284730911255 seconds
inference_precision: fp16
inference_compile_target: predictor
inference_compile_mode: reduce-overhead
==== CONFIG ====
cache_dir: null
solver:
_target_: stable_worldmodel.solver.CEMSolver
model: ???
batch_size: 8
num_samples: 64
var_scale: 1.0
n_steps: 10
topk: 8
device: cuda
seed: ${seed}
world:
env_name: swm/PushT-v1
num_envs: ${eval.num_eval}
max_episode_steps: ???
history_size: 1
frame_skip: 1
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
- state
seed: 42
policy: pusht/lewm
inference_precision: fp16
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
save_video: false
dataset_name: pusht_expert_train
callables:
- method: _set_state
args:
state:
value: state
- method: _set_goal_state
args:
goal_state:
value: goal_state
output:
filename: pusht_results.txt
==== RESULTS ====
metrics: {'success_rate': 90.0, 'episode_successes': array([ True, True, True, True, True, True, True, True, True,
True, True, True, True, True, False, False, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, False,
True, True, True, True, True, True, False, True, True,
True, True, True, False, True]), 'seeds': None}
evaluation_time: 16.081845998764038 seconds
inference_precision: fp16
inference_compile_target: predictor
inference_compile_mode: reduce-overhead
==== CONFIG ====
cache_dir: null
solver:
_target_: stable_worldmodel.solver.GradientSolver
model: ???
n_steps: 10
batch_size: 8
num_samples: 1
action_noise: 0
device: cuda
seed: ${seed}
optimizer_cls:
_target_: hydra.utils.get_class
path: torch.optim.AdamW
optimizer_kwargs:
lr: 0.05
world:
env_name: swm/PushT-v1
num_envs: ${eval.num_eval}
max_episode_steps: ???
history_size: 1
frame_skip: 1
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
- state
seed: 42
policy: pusht/lewm
inference_precision: fp16
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
save_video: false
dataset_name: pusht_expert_train
callables:
- method: _set_state
args:
state:
value: state
- method: _set_goal_state
args:
goal_state:
value: goal_state
output:
filename: pusht_results.txt
==== RESULTS ====
metrics: {'success_rate': 46.0, 'episode_successes': array([False, False, True, False, True, True, True, False, False,
True, False, False, True, False, False, False, False, False,
True, True, False, True, False, True, True, False, True,
False, True, True, True, False, False, True, False, False,
True, True, True, False, False, False, False, True, True,
True, True, False, False, False]), 'seeds': None}
evaluation_time: 63.84614443778992 seconds
inference_precision: fp16
inference_compile_target: predictor
inference_compile_mode: reduce-overhead
==== CONFIG ====
cache_dir: null
solver:
_target_: stable_worldmodel.solver.GradientSolver
model: ???
n_steps: 125
batch_size: 50
num_samples: 1
action_noise: 0
device: cuda
seed: ${seed}
optimizer_cls:
_target_: hydra.utils.get_class
path: torch.optim.AdamW
optimizer_kwargs:
lr: 0.05
world:
env_name: swm/PushT-v1
num_envs: ${eval.num_eval}
max_episode_steps: ???
history_size: 1
frame_skip: 1
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
- state
seed: 42
policy: pusht/lewm
inference_precision: fp16
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
save_video: false
dataset_name: pusht_expert_train
callables:
- method: _set_state
args:
state:
value: state
- method: _set_goal_state
args:
goal_state:
value: goal_state
output:
filename: pusht_results.txt
profile:
enabled: false
==== RESULTS ====
metrics: {'success_rate': 90.0, 'episode_successes': array([ True, False, True, True, True, True, True, True, True,
True, False, True, True, True, True, True, False, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, False, True, True, False, True, True,
True, True, True, True, True]), 'seeds': None}
evaluation_time: 15.638921022415161 seconds
inference_precision: fp16
inference_compile_target: predictor
inference_compile_mode: reduce-overhead
==== CONFIG ====
cache_dir: null
solver:
_target_: stable_worldmodel.solver.GradientSolver
model: ???
n_steps: 125
batch_size: 50
num_samples: 1
action_noise: 0
device: cuda
seed: ${seed}
optimizer_cls:
_target_: hydra.utils.get_class
path: torch.optim.AdamW
optimizer_kwargs:
lr: 0.05
world:
env_name: swm/PushT-v1
num_envs: ${eval.num_eval}
max_episode_steps: ???
history_size: 1
frame_skip: 1
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
- state
seed: 42
policy: pusht/lewm
inference_precision: fp16
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
save_video: false
dataset_name: pusht_expert_train
callables:
- method: _set_state
args:
state:
value: state
- method: _set_goal_state
args:
goal_state:
value: goal_state
output:
filename: pusht_results.txt
==== RESULTS ====
metrics: {'success_rate': 90.0, 'episode_successes': array([ True, False, True, True, True, True, True, True, True,
True, False, True, True, True, True, True, False, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, False, True, True, False, True, True,
True, True, True, True, True]), 'seeds': None}
evaluation_time: 16.060093879699707 seconds
inference_precision: fp16
inference_compile_target: predictor
inference_compile_mode: reduce-overhead