Initial commit
This commit is contained in:
61
config/eval/cube.yaml
Normal file
61
config/eval/cube.yaml
Normal file
@@ -0,0 +1,61 @@
|
||||
defaults:
|
||||
- launcher: local
|
||||
- solver: cem
|
||||
- _self_
|
||||
|
||||
world:
|
||||
env_name: swm/OGBCube-v0
|
||||
num_envs: ${eval.num_eval}
|
||||
max_episode_steps: ??? # make sure it's >= eval_budget
|
||||
history_size: 1
|
||||
frame_skip: 1
|
||||
env_type: single
|
||||
ob_type: states
|
||||
multiview: False
|
||||
width: 224
|
||||
height: 224
|
||||
visualize_info: False
|
||||
terminate_at_goal: True
|
||||
|
||||
dataset:
|
||||
stats: ${eval.dataset_name}
|
||||
keys_to_cache:
|
||||
- action
|
||||
|
||||
seed: 42
|
||||
policy: random # ckpt name or random
|
||||
|
||||
plan_config:
|
||||
horizon: 5
|
||||
receding_horizon: 5
|
||||
action_block: 5 # frameskip
|
||||
|
||||
# evaluation from dataset (replay expert trajectories)
|
||||
eval:
|
||||
num_eval: 50
|
||||
goal_offset_steps: 25
|
||||
eval_budget: 50
|
||||
img_size: 224
|
||||
dataset_name: ogbench/cube_single_expert
|
||||
callables:
|
||||
# -- set state
|
||||
- method: set_state
|
||||
args:
|
||||
qpos:
|
||||
value: qpos
|
||||
qvel:
|
||||
value: qvel
|
||||
# -- set target pos
|
||||
- method: set_target_pos
|
||||
args:
|
||||
cube_id:
|
||||
value: 0
|
||||
in_dataset: False
|
||||
target_pos:
|
||||
value: goal_privileged_block_0_pos
|
||||
target_quat:
|
||||
value: goal_privileged_block_0_quat
|
||||
|
||||
output:
|
||||
filename: ogb_cube_results.txt
|
||||
|
||||
7
config/eval/launcher/local.yaml
Normal file
7
config/eval/launcher/local.yaml
Normal file
@@ -0,0 +1,7 @@
|
||||
# @package _global_
|
||||
# Local launcher configuration (no SLURM)
|
||||
|
||||
defaults:
|
||||
- override /hydra/launcher: basic
|
||||
|
||||
cache_dir: null # use stable-worldmodel default cache
|
||||
48
config/eval/pusht.yaml
Normal file
48
config/eval/pusht.yaml
Normal file
@@ -0,0 +1,48 @@
|
||||
defaults:
|
||||
- launcher: local
|
||||
- solver: cem
|
||||
- _self_
|
||||
|
||||
world:
|
||||
env_name: swm/PushT-v1
|
||||
num_envs: ${eval.num_eval}
|
||||
max_episode_steps: ??? # make sure it's >= eval_budget
|
||||
history_size: 1
|
||||
frame_skip: 1
|
||||
|
||||
dataset:
|
||||
stats: ${eval.dataset_name}
|
||||
keys_to_cache:
|
||||
- action
|
||||
- proprio
|
||||
- state
|
||||
|
||||
seed: 42
|
||||
policy: random # ckpt name or random
|
||||
|
||||
plan_config:
|
||||
horizon: 5
|
||||
receding_horizon: 5
|
||||
action_block: 5 # frameskip
|
||||
|
||||
# evaluation from dataset (replay expert trajectories)
|
||||
eval:
|
||||
num_eval: 50
|
||||
goal_offset_steps: 25
|
||||
eval_budget: 50
|
||||
img_size: 224
|
||||
dataset_name: pusht_expert_train
|
||||
callables:
|
||||
# -- set state
|
||||
- method: _set_state
|
||||
args:
|
||||
state:
|
||||
value: state
|
||||
# -- set goal state
|
||||
- method: _set_goal_state
|
||||
args:
|
||||
goal_state:
|
||||
value: goal_state
|
||||
|
||||
output:
|
||||
filename: pusht_results.txt
|
||||
50
config/eval/reacher.yaml
Normal file
50
config/eval/reacher.yaml
Normal file
@@ -0,0 +1,50 @@
|
||||
defaults:
|
||||
- launcher: local
|
||||
- solver: cem
|
||||
- _self_
|
||||
|
||||
world:
|
||||
env_name: swm/ReacherDMControl-v0
|
||||
num_envs: ${eval.num_eval}
|
||||
max_episode_steps: ??? # make sure it's >= eval_budget
|
||||
history_size: 1
|
||||
frame_skip: 1
|
||||
task: qpos_match
|
||||
|
||||
dataset:
|
||||
stats: ${eval.dataset_name}
|
||||
keys_to_cache:
|
||||
- action
|
||||
|
||||
seed: 42
|
||||
policy: random # ckpt name or random
|
||||
|
||||
plan_config:
|
||||
horizon: 5
|
||||
receding_horizon: 5
|
||||
action_block: 5 # frameskip
|
||||
|
||||
# evaluation from dataset (replay expert trajectories)
|
||||
eval:
|
||||
num_eval: 50
|
||||
goal_offset_steps: 25
|
||||
eval_budget: 50
|
||||
img_size: 224
|
||||
dataset_name: dmc/reacher_random
|
||||
callables:
|
||||
# -- set state
|
||||
- method: set_state
|
||||
args:
|
||||
qpos:
|
||||
value: qpos
|
||||
qvel:
|
||||
value: qvel
|
||||
|
||||
- method: set_target_qpos
|
||||
args:
|
||||
target_qpos:
|
||||
value: goal_qpos
|
||||
|
||||
output:
|
||||
filename: dmc_results.txt
|
||||
|
||||
13
config/eval/solver/adam.yaml
Normal file
13
config/eval/solver/adam.yaml
Normal file
@@ -0,0 +1,13 @@
|
||||
_target_: stable_worldmodel.solver.GradientSolver
|
||||
model: ???
|
||||
n_steps: 30
|
||||
batch_size: 1
|
||||
num_samples: 100
|
||||
action_noise: 0
|
||||
device: "cuda"
|
||||
seed: ${seed}
|
||||
optimizer_cls:
|
||||
_target_: hydra.utils.get_class
|
||||
path: torch.optim.AdamW
|
||||
optimizer_kwargs:
|
||||
lr: 0.1
|
||||
9
config/eval/solver/cem.yaml
Normal file
9
config/eval/solver/cem.yaml
Normal file
@@ -0,0 +1,9 @@
|
||||
_target_: stable_worldmodel.solver.CEMSolver
|
||||
model: ???
|
||||
batch_size: 1
|
||||
num_samples: 300
|
||||
var_scale: 1.0
|
||||
n_steps: 30
|
||||
topk: 30
|
||||
device: "cuda"
|
||||
seed: ${seed}
|
||||
47
config/eval/tworoom.yaml
Normal file
47
config/eval/tworoom.yaml
Normal file
@@ -0,0 +1,47 @@
|
||||
defaults:
|
||||
- launcher: local
|
||||
- solver: cem
|
||||
- _self_
|
||||
|
||||
world:
|
||||
env_name: swm/TwoRoom-v1
|
||||
num_envs: ${eval.num_eval}
|
||||
max_episode_steps: ??? # make sure it's >= eval_budget
|
||||
history_size: 1
|
||||
frame_skip: 1
|
||||
|
||||
seed: 42
|
||||
policy: random # ckpt name or random
|
||||
|
||||
dataset:
|
||||
stats: ${eval.dataset_name}
|
||||
keys_to_cache:
|
||||
- action
|
||||
- proprio
|
||||
|
||||
plan_config:
|
||||
horizon: 5
|
||||
receding_horizon: 5
|
||||
action_block: 5 # frameskip
|
||||
|
||||
# evaluation from dataset (replay expert trajectories)
|
||||
eval:
|
||||
num_eval: 50
|
||||
goal_offset_steps: 25
|
||||
eval_budget: 50
|
||||
img_size: 224
|
||||
dataset_name: tworoom
|
||||
callables:
|
||||
# -- set state
|
||||
- method: _set_state
|
||||
args:
|
||||
state:
|
||||
value: proprio
|
||||
# -- set goal state
|
||||
- method: _set_goal_state
|
||||
args:
|
||||
goal_state:
|
||||
value: goal_proprio
|
||||
|
||||
output:
|
||||
filename: tworoom_results.txt
|
||||
11
config/train/data/dmc.yaml
Normal file
11
config/train/data/dmc.yaml
Normal file
@@ -0,0 +1,11 @@
|
||||
dataset:
|
||||
num_steps: ${eval:'${wm.num_preds} + ${wm.history_size}'}
|
||||
frameskip: 5
|
||||
name: reacher
|
||||
keys_to_load:
|
||||
- pixels
|
||||
- action
|
||||
- observation
|
||||
keys_to_cache:
|
||||
- action
|
||||
- observation
|
||||
13
config/train/data/ogb.yaml
Normal file
13
config/train/data/ogb.yaml
Normal file
@@ -0,0 +1,13 @@
|
||||
dataset:
|
||||
name: ogbench/cube_single_expert
|
||||
num_steps: ${eval:'${wm.num_preds} + ${wm.history_size}'}
|
||||
frameskip: 5
|
||||
keys_to_load:
|
||||
- pixels
|
||||
- action
|
||||
- observation
|
||||
keys_to_cache:
|
||||
- action
|
||||
- observation
|
||||
keys_to_merge:
|
||||
proprio: proprio
|
||||
13
config/train/data/pusht.yaml
Normal file
13
config/train/data/pusht.yaml
Normal file
@@ -0,0 +1,13 @@
|
||||
dataset:
|
||||
num_steps: ${eval:'${wm.num_preds} + ${wm.history_size}'}
|
||||
frameskip: 5
|
||||
name: pusht_expert_train
|
||||
keys_to_load:
|
||||
- pixels
|
||||
- action
|
||||
- proprio
|
||||
- state
|
||||
keys_to_cache:
|
||||
- action
|
||||
- proprio
|
||||
- state
|
||||
11
config/train/data/tworoom.yaml
Normal file
11
config/train/data/tworoom.yaml
Normal file
@@ -0,0 +1,11 @@
|
||||
dataset:
|
||||
num_steps: ${eval:'${wm.num_preds} + ${wm.history_size}'}
|
||||
frameskip: 5
|
||||
name: tworoom
|
||||
keys_to_load:
|
||||
- pixels
|
||||
- action
|
||||
- proprio
|
||||
keys_to_cache:
|
||||
- action
|
||||
- proprio
|
||||
11
config/train/launcher/local.yaml
Normal file
11
config/train/launcher/local.yaml
Normal file
@@ -0,0 +1,11 @@
|
||||
# @package _global_
|
||||
# Local launcher configuration (no SLURM)
|
||||
|
||||
defaults:
|
||||
- override /hydra/launcher: basic
|
||||
|
||||
wandb:
|
||||
enabled: True
|
||||
config:
|
||||
project: le-wm
|
||||
entity: le-wm
|
||||
64
config/train/lewm.yaml
Normal file
64
config/train/lewm.yaml
Normal file
@@ -0,0 +1,64 @@
|
||||
defaults:
|
||||
- _self_
|
||||
- data: pusht
|
||||
|
||||
output_model_name: lewm
|
||||
subdir: ${hydra:job.id}
|
||||
|
||||
num_workers: 6
|
||||
train_split: 0.9
|
||||
seed: 3072
|
||||
img_size: 224
|
||||
patch_size: 14
|
||||
encoder_scale: tiny
|
||||
dump_object: True
|
||||
|
||||
trainer:
|
||||
max_epochs: 100
|
||||
devices: auto
|
||||
accelerator: gpu
|
||||
precision: bf16
|
||||
gradient_clip_val: 1.0
|
||||
|
||||
loader:
|
||||
batch_size: 128
|
||||
num_workers: ${num_workers}
|
||||
persistent_workers: True
|
||||
prefetch_factor: 3
|
||||
pin_memory: True
|
||||
|
||||
optimizer:
|
||||
type: AdamW
|
||||
lr: 5e-5
|
||||
weight_decay: 1e-3
|
||||
|
||||
wandb:
|
||||
enabled: True
|
||||
config:
|
||||
entity: lewm
|
||||
project: lewm
|
||||
name: ${output_model_name}
|
||||
id: ${subdir}
|
||||
resume: allow
|
||||
log_model: False
|
||||
|
||||
wm:
|
||||
type: lewm
|
||||
history_size: 3
|
||||
num_preds: 1
|
||||
embed_dim: 192
|
||||
|
||||
predictor:
|
||||
depth: 6
|
||||
heads: 16
|
||||
mlp_dim: 2048
|
||||
dim_head: 64
|
||||
dropout: 0.1
|
||||
emb_dropout: 0.0
|
||||
|
||||
loss:
|
||||
sigreg:
|
||||
weight: 0.09
|
||||
kwargs:
|
||||
knots: 17
|
||||
num_proj: 1024
|
||||
Reference in New Issue
Block a user