更改求解器：step=150 时成功率更高；step=125 时速度更快，成功率持平
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,4 +1,3 @@
|
|||||||
.venv/
|
|
||||||
outputs/
|
outputs/
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
|
|||||||
0
.venv/.gitignore
vendored
Normal file
0
.venv/.gitignore
vendored
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
"""Gradient-based solver for model-based planning."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from gymnasium.spaces import Box
|
||||||
|
from loguru import logger as logging
|
||||||
|
|
||||||
|
from .solver import Costable
|
||||||
|
|
||||||
|
|
||||||
|
class GradientSolver(torch.nn.Module):
    """Gradient-based solver using backpropagation through the world model.

    Candidate action sequences live in the ``init`` parameter, laid out as
    ``(n_envs, num_samples, horizon, action_dim)``. ``solve`` refines them by
    running ``n_steps`` optimizer updates on the summed cost returned by
    ``model.get_cost`` and returns, per environment, the sample with the
    lowest cost.

    Args:
        model: World model implementing the Costable protocol.
        n_steps: Number of gradient descent iterations (must be >= 1 when
            ``solve`` is called).
        batch_size: Number of environments to process in parallel. ``None``
            processes all environments in a single batch.
        var_scale: Scale of the Gaussian noise added at initialization to
            every sample except the first one.
        num_samples: Number of action samples to optimize in parallel.
        action_noise: Scale of the Gaussian noise added to the actions after
            each optimizer step (disabled when ``<= 0``).
        device: Device for tensor computations.
        seed: Random seed for reproducibility.
        optimizer_cls: PyTorch optimizer class to use.
        optimizer_kwargs: Keyword arguments for the optimizer. Defaults to
            ``{'lr': 1.0}`` when omitted.
    """

    def __init__(
        self,
        model: Costable,
        n_steps: int,
        batch_size: int | None = None,
        var_scale: float = 1,
        num_samples: int = 1,
        action_noise: float = 0.0,
        device: str | torch.device = 'cpu',
        seed: int = 1234,
        optimizer_cls: type[torch.optim.Optimizer] = torch.optim.SGD,
        optimizer_kwargs: dict | None = None,
    ) -> None:
        super().__init__()
        self.model = model
        self.n_steps = n_steps
        self.batch_size = batch_size
        self.num_samples = num_samples
        self.var_scale = var_scale
        self.action_noise = action_noise
        self.device = device
        # Dedicated generator so sampling is reproducible for a given seed.
        self.torch_gen = torch.Generator(device=device).manual_seed(seed)

        self.optimizer_cls = optimizer_cls
        self.optimizer_kwargs = (
            optimizer_kwargs if optimizer_kwargs is not None else {'lr': 1.0}
        )

        # Populated by configure().
        self._configured = False
        self._n_envs = None
        self._action_dim = None
        self._config = None

    def configure(
        self, *, action_space: gym.Space, n_envs: int, config: Any
    ) -> None:
        """Configure the solver with environment specifications.

        Args:
            action_space: Action space of the environment(s). Its first
                dimension is dropped when computing the per-step action size
                (presumably the vectorized-env axis -- TODO confirm).
            n_envs: Number of parallel environments.
            config: Planning config; ``horizon`` and ``action_block`` are
                read from it by the properties below.
        """
        self._action_space = action_space
        self._n_envs = n_envs
        self._config = config
        self._action_dim = int(np.prod(action_space.shape[1:]))
        self._configured = True

        if not isinstance(action_space, Box):
            logging.warning(
                f'Action space is discrete, got {type(action_space)}. GradientSolver may not work as expected.'
            )

    @property
    def n_envs(self) -> int:
        """Number of parallel environments."""
        return self._n_envs

    @property
    def action_dim(self) -> int:
        """Flattened action dimension including action_block grouping."""
        return self._action_dim * self._config.action_block

    @property
    def horizon(self) -> int:
        """Planning horizon in timesteps."""
        return self._config.horizon

    def __call__(self, *args: Any, **kwargs: Any) -> dict:
        """Make solver callable, forwarding to solve()."""
        return self.solve(*args, **kwargs)

    def init_action(self, actions: torch.Tensor | None = None) -> None:
        """Initialize the action tensor for optimization.

        Args:
            actions: Optional warm-start actions whose dim 1 is time; they
                are zero-padded along that axis up to ``horizon``. When
                ``None``, all actions start at zero.
        """
        device = torch.device(self.device)
        if actions is None:
            actions = torch.zeros(
                (self._n_envs, 0, self.action_dim), device=device
            )
        elif actions.device != device:
            actions = actions.to(device, non_blocking=True)

        # Zero-fill the timesteps the warm start does not cover.
        remaining = self.horizon - actions.shape[1]
        if remaining > 0:
            new_actions = torch.zeros(
                self._n_envs,
                remaining,
                self.action_dim,
                device=actions.device,
                dtype=actions.dtype,
            )
            actions = torch.cat([actions, new_actions], dim=1)

        # Add the sample dimension at dim 1.
        actions = actions.unsqueeze(1).repeat_interleave(
            self.num_samples, dim=1
        )
        # Perturb every sample except the first, which keeps the exact init.
        actions[:, 1:] += (
            torch.randn(
                actions[:, 1:].shape,
                generator=self.torch_gen,
                device=self.device,
            )
            * self.var_scale
        )

        # Reuse the existing parameter when present, otherwise register it.
        if hasattr(self, 'init'):
            self.init.copy_(actions)
        else:
            self.register_parameter('init', torch.nn.Parameter(actions))

    def _expand_infos(
        self, info_dict: dict, start_idx: int, end_idx: int
    ) -> dict:
        """Slice array values to one env batch and add a sample axis."""
        current_bs = end_idx - start_idx
        expanded = {}
        for key, value in info_dict.items():
            # Assumes dim 0 of array values corresponds to n_envs.
            if torch.is_tensor(value):
                batch_v = value[start_idx:end_idx].unsqueeze(1)
                batch_v = batch_v.expand(
                    current_bs, self.num_samples, *batch_v.shape[2:]
                )
            elif isinstance(value, np.ndarray):
                batch_v = np.repeat(
                    value[start_idx:end_idx][:, None, ...],
                    self.num_samples,
                    axis=1,
                )
            else:
                # Non-array entries are forwarded unchanged instead of
                # silently reusing the previous key's value.
                batch_v = value
            expanded[key] = batch_v
        return expanded

    def solve(
        self, info_dict: dict, init_action: torch.Tensor | None = None
    ) -> dict:
        """Solve the planning problem using gradient descent.

        Args:
            info_dict: Inputs forwarded to ``model.get_cost``. Tensor and
                ndarray values are sliced per env batch and given a sample
                axis; other values are passed through unchanged.
            init_action: Optional warm-start actions, see ``init_action``.

        Returns:
            Dict with ``'actions'`` (lowest-cost action sequence per
            environment) and ``'cost'`` (list with, per environment, the
            minimum cost over samples from the last iteration).

        Raises:
            ValueError: If ``n_steps`` is smaller than 1.
        """
        if self.n_steps < 1:
            raise ValueError(
                f'n_steps must be >= 1 to produce a cost, got {self.n_steps}'
            )

        start_time = time.perf_counter()

        with torch.no_grad():
            self.init_action(init_action)

        # Default to all envs in one batch (can be memory hungry).
        total_envs = self.n_envs
        batch_size = (
            self.batch_size if self.batch_size is not None else total_envs
        )

        # Per-batch results, concatenated once all batches are done.
        batch_costs = []
        batch_top_actions = []

        for start_idx in range(0, total_envs, batch_size):
            end_idx = min(start_idx + batch_size, total_envs)
            current_bs = end_idx - start_idx

            # Optimize a detached copy so each batch gets a fresh graph.
            batch_init = self.init[start_idx:end_idx].detach().clone()
            batch_init.requires_grad_(True)

            optim = self.optimizer_cls([batch_init], **self.optimizer_kwargs)
            expanded_infos = self._expand_infos(info_dict, start_idx, end_idx)

            costs = None
            for _ in range(self.n_steps):
                # Shallow copy so the model cannot add keys to our dict.
                current_info = expanded_infos.copy()
                costs = self.model.get_cost(current_info, batch_init)

                assert isinstance(costs, torch.Tensor), (
                    f'Got {type(costs)} cost, expect torch.Tensor'
                )
                assert (
                    costs.ndim == 2
                    and costs.shape[0] == current_bs
                    and costs.shape[1] == self.num_samples
                ), (
                    f'Cost should be of shape ({current_bs}, {self.num_samples}), got {costs.shape}'
                )
                assert costs.requires_grad, (
                    'Cost must requires_grad for GD solver.'
                )

                costs.sum().backward()
                optim.step()
                optim.zero_grad(set_to_none=True)

                if self.action_noise > 0:
                    batch_init.data += (
                        torch.randn(
                            batch_init.shape,
                            generator=self.torch_gen,
                            device=self.device,
                        )
                        * self.action_noise
                    )

            # NOTE: costs were evaluated before the last optimizer step, so
            # they rank the samples as of the start of the final iteration.
            costs = costs.detach()
            batch_costs.append(costs.min(dim=1).values)

            # Write the optimized batch back into the global parameter.
            with torch.no_grad():
                self.init[start_idx:end_idx] = batch_init

            top_idx = costs.argmin(dim=1)
            batch_indices = torch.arange(current_bs, device=self.device)
            batch_top_actions.append(
                batch_init[batch_indices, top_idx].detach()
            )

        outputs = {
            'cost': torch.cat(batch_costs).cpu().tolist(),
            'actions': torch.cat(batch_top_actions, dim=0),
        }

        elapsed = time.perf_counter() - start_time
        logging.info(
            f'GradientSolver.solve completed in {elapsed:.4f} seconds.'
        )

        return outputs
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
defaults:
|
defaults:
|
||||||
- launcher: local
|
- launcher: local
|
||||||
- solver: cem
|
- solver: gradient
|
||||||
- _self_
|
- _self_
|
||||||
|
|
||||||
world:
|
world:
|
||||||
|
|||||||
14
config/eval/solver/gradient.yaml
Normal file
14
config/eval/solver/gradient.yaml
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
_target_: stable_worldmodel.solver.GradientSolver
|
||||||
|
model: ???
|
||||||
|
# Original adam.yaml reference: n_steps=30, num_samples=100, batch_size=1, lr=0.1.
|
||||||
|
n_steps: 125
|
||||||
|
batch_size: 50
|
||||||
|
num_samples: 1
|
||||||
|
action_noise: 0
|
||||||
|
device: "cuda"
|
||||||
|
seed: ${seed}
|
||||||
|
optimizer_cls:
|
||||||
|
_target_: hydra.utils.get_class
|
||||||
|
path: torch.optim.AdamW
|
||||||
|
optimizer_kwargs:
|
||||||
|
lr: 0.05
|
||||||
9
eval.py
9
eval.py
@@ -153,6 +153,12 @@ def get_inference_context(cfg, device):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_eval_grad_context(solver=None):
    """Return the autograd context manager to wrap evaluation in.

    A ``GradientSolver`` gets ``torch.enable_grad()``; any other solver
    (including ``None``) gets ``torch.inference_mode()``.
    """
    needs_grad = isinstance(solver, swm.solver.GradientSolver)
    return torch.enable_grad() if needs_grad else torch.inference_mode()
|
||||||
|
|
||||||
|
|
||||||
def make_profiler(cfg, results_path):
|
def make_profiler(cfg, results_path):
|
||||||
profile_cfg = get_profile_cfg(cfg)
|
profile_cfg = get_profile_cfg(cfg)
|
||||||
if not profile_cfg["enabled"]:
|
if not profile_cfg["enabled"]:
|
||||||
@@ -345,6 +351,7 @@ def run_eval_subset(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
policy = swm.policy.RandomPolicy()
|
policy = swm.policy.RandomPolicy()
|
||||||
|
solver = None
|
||||||
inference_ctx = nullcontext()
|
inference_ctx = nullcontext()
|
||||||
inference_precision = "fp32"
|
inference_precision = "fp32"
|
||||||
compile_cfg = get_compile_cfg(local_cfg)
|
compile_cfg = get_compile_cfg(local_cfg)
|
||||||
@@ -357,7 +364,7 @@ def run_eval_subset(
|
|||||||
if str(device).startswith("cuda") and torch.cuda.is_available():
|
if str(device).startswith("cuda") and torch.cuda.is_available():
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
with torch.inference_mode():
|
with get_eval_grad_context(solver):
|
||||||
with profiler_ctx as profiler:
|
with profiler_ctx as profiler:
|
||||||
with inference_ctx:
|
with inference_ctx:
|
||||||
with torch.profiler.record_function("eval.world_evaluate_from_dataset"):
|
with torch.profiler.record_function("eval.world_evaluate_from_dataset"):
|
||||||
|
|||||||
21
jepa.py
21
jepa.py
@@ -189,6 +189,27 @@ class JEPA(nn.Module):
|
|||||||
|
|
||||||
pred_rollout = self.predict(emb_hist, act_emb_hist)[:, -1:]
|
pred_rollout = self.predict(emb_hist, act_emb_hist)[:, -1:]
|
||||||
else:
|
else:
|
||||||
|
if torch.is_grad_enabled() and action_sequence.requires_grad:
|
||||||
|
emb_slots = init_hist.split(1, dim=1)
|
||||||
|
act_slots = act_hist.split(1, dim=1)
|
||||||
|
|
||||||
|
for t in range(act_future.size(1)):
|
||||||
|
emb_view = torch.cat(emb_slots[-HS:], dim=1)
|
||||||
|
act_view = torch.cat(act_slots[-HS:], dim=1)
|
||||||
|
pred_emb = self.predict(emb_view, act_view)[:, -1:]
|
||||||
|
next_act_emb = act_future[:, t : t + 1]
|
||||||
|
|
||||||
|
emb_slots = (*emb_slots[-(HS - 1) :], pred_emb)
|
||||||
|
act_slots = (*act_slots[-(HS - 1) :], next_act_emb)
|
||||||
|
|
||||||
|
emb_view = torch.cat(emb_slots[-HS:], dim=1)
|
||||||
|
act_view = torch.cat(act_slots[-HS:], dim=1)
|
||||||
|
pred_rollout = self.predict(emb_view, act_view)[:, -1:]
|
||||||
|
info["predicted_emb"] = pred_rollout.reshape(
|
||||||
|
B, S, *pred_rollout.shape[1:]
|
||||||
|
)
|
||||||
|
return info
|
||||||
|
|
||||||
emb_hist = init_hist.new_empty((B * S, HS, init_hist.size(-1)))
|
emb_hist = init_hist.new_empty((B * S, HS, init_hist.size(-1)))
|
||||||
act_emb_hist = action_emb.new_empty((B * S, HS, action_emb.size(-1)))
|
act_emb_hist = action_emb.new_empty((B * S, HS, action_emb.size(-1)))
|
||||||
emb_hist[:, :hist_len].copy_(init_hist)
|
emb_hist[:, :hist_len].copy_(init_hist)
|
||||||
|
|||||||
@@ -1228,3 +1228,265 @@ evaluation_time: 16.38284730911255 seconds
|
|||||||
inference_precision: fp16
|
inference_precision: fp16
|
||||||
inference_compile_target: predictor
|
inference_compile_target: predictor
|
||||||
inference_compile_mode: reduce-overhead
|
inference_compile_mode: reduce-overhead
|
||||||
|
|
||||||
|
==== CONFIG ====
|
||||||
|
cache_dir: null
|
||||||
|
solver:
|
||||||
|
_target_: stable_worldmodel.solver.CEMSolver
|
||||||
|
model: ???
|
||||||
|
batch_size: 8
|
||||||
|
num_samples: 64
|
||||||
|
var_scale: 1.0
|
||||||
|
n_steps: 10
|
||||||
|
topk: 8
|
||||||
|
device: cuda
|
||||||
|
seed: ${seed}
|
||||||
|
world:
|
||||||
|
env_name: swm/PushT-v1
|
||||||
|
num_envs: ${eval.num_eval}
|
||||||
|
max_episode_steps: ???
|
||||||
|
history_size: 1
|
||||||
|
frame_skip: 1
|
||||||
|
dataset:
|
||||||
|
stats: ${eval.dataset_name}
|
||||||
|
keys_to_cache:
|
||||||
|
- action
|
||||||
|
- proprio
|
||||||
|
- state
|
||||||
|
seed: 42
|
||||||
|
policy: pusht/lewm
|
||||||
|
inference_precision: fp16
|
||||||
|
plan_config:
|
||||||
|
horizon: 5
|
||||||
|
receding_horizon: 5
|
||||||
|
action_block: 5
|
||||||
|
eval:
|
||||||
|
num_eval: 50
|
||||||
|
goal_offset_steps: 25
|
||||||
|
eval_budget: 50
|
||||||
|
img_size: 224
|
||||||
|
save_video: false
|
||||||
|
dataset_name: pusht_expert_train
|
||||||
|
callables:
|
||||||
|
- method: _set_state
|
||||||
|
args:
|
||||||
|
state:
|
||||||
|
value: state
|
||||||
|
- method: _set_goal_state
|
||||||
|
args:
|
||||||
|
goal_state:
|
||||||
|
value: goal_state
|
||||||
|
output:
|
||||||
|
filename: pusht_results.txt
|
||||||
|
|
||||||
|
==== RESULTS ====
|
||||||
|
metrics: {'success_rate': 90.0, 'episode_successes': array([ True, True, True, True, True, True, True, True, True,
|
||||||
|
True, True, True, True, True, False, False, True, True,
|
||||||
|
True, True, True, True, True, True, True, True, True,
|
||||||
|
True, True, True, True, True, True, True, True, False,
|
||||||
|
True, True, True, True, True, True, False, True, True,
|
||||||
|
True, True, True, False, True]), 'seeds': None}
|
||||||
|
evaluation_time: 16.081845998764038 seconds
|
||||||
|
inference_precision: fp16
|
||||||
|
inference_compile_target: predictor
|
||||||
|
inference_compile_mode: reduce-overhead
|
||||||
|
|
||||||
|
==== CONFIG ====
|
||||||
|
cache_dir: null
|
||||||
|
solver:
|
||||||
|
_target_: stable_worldmodel.solver.GradientSolver
|
||||||
|
model: ???
|
||||||
|
n_steps: 10
|
||||||
|
batch_size: 8
|
||||||
|
num_samples: 1
|
||||||
|
action_noise: 0
|
||||||
|
device: cuda
|
||||||
|
seed: ${seed}
|
||||||
|
optimizer_cls:
|
||||||
|
_target_: hydra.utils.get_class
|
||||||
|
path: torch.optim.AdamW
|
||||||
|
optimizer_kwargs:
|
||||||
|
lr: 0.05
|
||||||
|
world:
|
||||||
|
env_name: swm/PushT-v1
|
||||||
|
num_envs: ${eval.num_eval}
|
||||||
|
max_episode_steps: ???
|
||||||
|
history_size: 1
|
||||||
|
frame_skip: 1
|
||||||
|
dataset:
|
||||||
|
stats: ${eval.dataset_name}
|
||||||
|
keys_to_cache:
|
||||||
|
- action
|
||||||
|
- proprio
|
||||||
|
- state
|
||||||
|
seed: 42
|
||||||
|
policy: pusht/lewm
|
||||||
|
inference_precision: fp16
|
||||||
|
plan_config:
|
||||||
|
horizon: 5
|
||||||
|
receding_horizon: 5
|
||||||
|
action_block: 5
|
||||||
|
eval:
|
||||||
|
num_eval: 50
|
||||||
|
goal_offset_steps: 25
|
||||||
|
eval_budget: 50
|
||||||
|
img_size: 224
|
||||||
|
save_video: false
|
||||||
|
dataset_name: pusht_expert_train
|
||||||
|
callables:
|
||||||
|
- method: _set_state
|
||||||
|
args:
|
||||||
|
state:
|
||||||
|
value: state
|
||||||
|
- method: _set_goal_state
|
||||||
|
args:
|
||||||
|
goal_state:
|
||||||
|
value: goal_state
|
||||||
|
output:
|
||||||
|
filename: pusht_results.txt
|
||||||
|
|
||||||
|
==== RESULTS ====
|
||||||
|
metrics: {'success_rate': 46.0, 'episode_successes': array([False, False, True, False, True, True, True, False, False,
|
||||||
|
True, False, False, True, False, False, False, False, False,
|
||||||
|
True, True, False, True, False, True, True, False, True,
|
||||||
|
False, True, True, True, False, False, True, False, False,
|
||||||
|
True, True, True, False, False, False, False, True, True,
|
||||||
|
True, True, False, False, False]), 'seeds': None}
|
||||||
|
evaluation_time: 63.84614443778992 seconds
|
||||||
|
inference_precision: fp16
|
||||||
|
inference_compile_target: predictor
|
||||||
|
inference_compile_mode: reduce-overhead
|
||||||
|
|
||||||
|
==== CONFIG ====
|
||||||
|
cache_dir: null
|
||||||
|
solver:
|
||||||
|
_target_: stable_worldmodel.solver.GradientSolver
|
||||||
|
model: ???
|
||||||
|
n_steps: 125
|
||||||
|
batch_size: 50
|
||||||
|
num_samples: 1
|
||||||
|
action_noise: 0
|
||||||
|
device: cuda
|
||||||
|
seed: ${seed}
|
||||||
|
optimizer_cls:
|
||||||
|
_target_: hydra.utils.get_class
|
||||||
|
path: torch.optim.AdamW
|
||||||
|
optimizer_kwargs:
|
||||||
|
lr: 0.05
|
||||||
|
world:
|
||||||
|
env_name: swm/PushT-v1
|
||||||
|
num_envs: ${eval.num_eval}
|
||||||
|
max_episode_steps: ???
|
||||||
|
history_size: 1
|
||||||
|
frame_skip: 1
|
||||||
|
dataset:
|
||||||
|
stats: ${eval.dataset_name}
|
||||||
|
keys_to_cache:
|
||||||
|
- action
|
||||||
|
- proprio
|
||||||
|
- state
|
||||||
|
seed: 42
|
||||||
|
policy: pusht/lewm
|
||||||
|
inference_precision: fp16
|
||||||
|
plan_config:
|
||||||
|
horizon: 5
|
||||||
|
receding_horizon: 5
|
||||||
|
action_block: 5
|
||||||
|
eval:
|
||||||
|
num_eval: 50
|
||||||
|
goal_offset_steps: 25
|
||||||
|
eval_budget: 50
|
||||||
|
img_size: 224
|
||||||
|
save_video: false
|
||||||
|
dataset_name: pusht_expert_train
|
||||||
|
callables:
|
||||||
|
- method: _set_state
|
||||||
|
args:
|
||||||
|
state:
|
||||||
|
value: state
|
||||||
|
- method: _set_goal_state
|
||||||
|
args:
|
||||||
|
goal_state:
|
||||||
|
value: goal_state
|
||||||
|
output:
|
||||||
|
filename: pusht_results.txt
|
||||||
|
profile:
|
||||||
|
enabled: false
|
||||||
|
|
||||||
|
==== RESULTS ====
|
||||||
|
metrics: {'success_rate': 90.0, 'episode_successes': array([ True, False, True, True, True, True, True, True, True,
|
||||||
|
True, False, True, True, True, True, True, False, True,
|
||||||
|
True, True, True, True, True, True, True, True, True,
|
||||||
|
True, True, True, True, True, True, True, True, True,
|
||||||
|
True, True, True, False, True, True, False, True, True,
|
||||||
|
True, True, True, True, True]), 'seeds': None}
|
||||||
|
evaluation_time: 15.638921022415161 seconds
|
||||||
|
inference_precision: fp16
|
||||||
|
inference_compile_target: predictor
|
||||||
|
inference_compile_mode: reduce-overhead
|
||||||
|
|
||||||
|
==== CONFIG ====
|
||||||
|
cache_dir: null
|
||||||
|
solver:
|
||||||
|
_target_: stable_worldmodel.solver.GradientSolver
|
||||||
|
model: ???
|
||||||
|
n_steps: 125
|
||||||
|
batch_size: 50
|
||||||
|
num_samples: 1
|
||||||
|
action_noise: 0
|
||||||
|
device: cuda
|
||||||
|
seed: ${seed}
|
||||||
|
optimizer_cls:
|
||||||
|
_target_: hydra.utils.get_class
|
||||||
|
path: torch.optim.AdamW
|
||||||
|
optimizer_kwargs:
|
||||||
|
lr: 0.05
|
||||||
|
world:
|
||||||
|
env_name: swm/PushT-v1
|
||||||
|
num_envs: ${eval.num_eval}
|
||||||
|
max_episode_steps: ???
|
||||||
|
history_size: 1
|
||||||
|
frame_skip: 1
|
||||||
|
dataset:
|
||||||
|
stats: ${eval.dataset_name}
|
||||||
|
keys_to_cache:
|
||||||
|
- action
|
||||||
|
- proprio
|
||||||
|
- state
|
||||||
|
seed: 42
|
||||||
|
policy: pusht/lewm
|
||||||
|
inference_precision: fp16
|
||||||
|
plan_config:
|
||||||
|
horizon: 5
|
||||||
|
receding_horizon: 5
|
||||||
|
action_block: 5
|
||||||
|
eval:
|
||||||
|
num_eval: 50
|
||||||
|
goal_offset_steps: 25
|
||||||
|
eval_budget: 50
|
||||||
|
img_size: 224
|
||||||
|
save_video: false
|
||||||
|
dataset_name: pusht_expert_train
|
||||||
|
callables:
|
||||||
|
- method: _set_state
|
||||||
|
args:
|
||||||
|
state:
|
||||||
|
value: state
|
||||||
|
- method: _set_goal_state
|
||||||
|
args:
|
||||||
|
goal_state:
|
||||||
|
value: goal_state
|
||||||
|
output:
|
||||||
|
filename: pusht_results.txt
|
||||||
|
|
||||||
|
==== RESULTS ====
|
||||||
|
metrics: {'success_rate': 90.0, 'episode_successes': array([ True, False, True, True, True, True, True, True, True,
|
||||||
|
True, False, True, True, True, True, True, False, True,
|
||||||
|
True, True, True, True, True, True, True, True, True,
|
||||||
|
True, True, True, True, True, True, True, True, True,
|
||||||
|
True, True, True, False, True, True, False, True, True,
|
||||||
|
True, True, True, True, True]), 'seeds': None}
|
||||||
|
evaluation_time: 16.060093879699707 seconds
|
||||||
|
inference_precision: fp16
|
||||||
|
inference_compile_target: predictor
|
||||||
|
inference_compile_mode: reduce-overhead
|
||||||
|
|||||||
Reference in New Issue
Block a user