Initial commit

This commit is contained in:
Lucas Maes
2026-03-12 22:56:21 -04:00
committed by lucas-maes
commit 83f97d72ad
21 changed files with 1355 additions and 0 deletions

61
config/eval/cube.yaml Normal file
View File

@@ -0,0 +1,61 @@
defaults:
- launcher: local
- solver: cem
- _self_
world:
env_name: swm/OGBCube-v0
num_envs: ${eval.num_eval}
max_episode_steps: ??? # make sure it's >= eval_budget
history_size: 1
frame_skip: 1
env_type: single
ob_type: states
multiview: False
width: 224
height: 224
visualize_info: False
terminate_at_goal: True
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
seed: 42
policy: random # ckpt name or random
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5 # frameskip
# evaluation from dataset (replay expert trajectories)
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: ogbench/cube_single_expert
callables:
# -- set state
- method: set_state
args:
qpos:
value: qpos
qvel:
value: qvel
# -- set target pos
- method: set_target_pos
args:
cube_id:
value: 0
in_dataset: False
target_pos:
value: goal_privileged_block_0_pos
target_quat:
value: goal_privileged_block_0_quat
output:
filename: ogb_cube_results.txt

View File

@@ -0,0 +1,7 @@
# @package _global_
# Local launcher configuration (no SLURM)
defaults:
- override /hydra/launcher: basic
cache_dir: null # use stable-worldmodel default cache

48
config/eval/pusht.yaml Normal file
View File

@@ -0,0 +1,48 @@
defaults:
- launcher: local
- solver: cem
- _self_
world:
env_name: swm/PushT-v1
num_envs: ${eval.num_eval}
max_episode_steps: ??? # make sure it's >= eval_budget
history_size: 1
frame_skip: 1
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
- state
seed: 42
policy: random # ckpt name or random
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5 # frameskip
# evaluation from dataset (replay expert trajectories)
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: pusht_expert_train
callables:
# -- set state
- method: _set_state
args:
state:
value: state
# -- set goal state
- method: _set_goal_state
args:
goal_state:
value: goal_state
output:
filename: pusht_results.txt

50
config/eval/reacher.yaml Normal file
View File

@@ -0,0 +1,50 @@
defaults:
- launcher: local
- solver: cem
- _self_
world:
env_name: swm/ReacherDMControl-v0
num_envs: ${eval.num_eval}
max_episode_steps: ??? # make sure it's >= eval_budget
history_size: 1
frame_skip: 1
task: qpos_match
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
seed: 42
policy: random # ckpt name or random
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5 # frameskip
# evaluation from dataset (replay expert trajectories)
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: dmc/reacher_random
callables:
# -- set state
- method: set_state
args:
qpos:
value: qpos
qvel:
value: qvel
- method: set_target_qpos
args:
target_qpos:
value: goal_qpos
output:
filename: dmc_results.txt

View File

@@ -0,0 +1,13 @@
_target_: stable_worldmodel.solver.GradientSolver
model: ???
n_steps: 30
batch_size: 1
num_samples: 100
action_noise: 0
device: "cuda"
seed: ${seed}
optimizer_cls:
_target_: hydra.utils.get_class
path: torch.optim.AdamW
optimizer_kwargs:
lr: 0.1

View File

@@ -0,0 +1,9 @@
_target_: stable_worldmodel.solver.CEMSolver
model: ???
batch_size: 1
num_samples: 300
var_scale: 1.0
n_steps: 30
topk: 30
device: "cuda"
seed: ${seed}

47
config/eval/tworoom.yaml Normal file
View File

@@ -0,0 +1,47 @@
defaults:
- launcher: local
- solver: cem
- _self_
world:
env_name: swm/TwoRoom-v1
num_envs: ${eval.num_eval}
max_episode_steps: ??? # make sure it's >= eval_budget
history_size: 1
frame_skip: 1
seed: 42
policy: random # ckpt name or random
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5 # frameskip
# evaluation from dataset (replay expert trajectories)
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: tworoom
callables:
# -- set state
- method: _set_state
args:
state:
value: proprio
# -- set goal state
- method: _set_goal_state
args:
goal_state:
value: goal_proprio
output:
filename: tworoom_results.txt

View File

@@ -0,0 +1,11 @@
dataset:
num_steps: ${eval:'${wm.num_preds} + ${wm.history_size}'}
frameskip: 5
name: reacher
keys_to_load:
- pixels
- action
- observation
keys_to_cache:
- action
- observation

View File

@@ -0,0 +1,13 @@
dataset:
name: ogbench/cube_single_expert
num_steps: ${eval:'${wm.num_preds} + ${wm.history_size}'}
frameskip: 5
keys_to_load:
- pixels
- action
- observation
keys_to_cache:
- action
- observation
keys_to_merge:
proprio: proprio

View File

@@ -0,0 +1,13 @@
dataset:
num_steps: ${eval:'${wm.num_preds} + ${wm.history_size}'}
frameskip: 5
name: pusht_expert_train
keys_to_load:
- pixels
- action
- proprio
- state
keys_to_cache:
- action
- proprio
- state

View File

@@ -0,0 +1,11 @@
dataset:
num_steps: ${eval:'${wm.num_preds} + ${wm.history_size}'}
frameskip: 5
name: tworoom
keys_to_load:
- pixels
- action
- proprio
keys_to_cache:
- action
- proprio

View File

@@ -0,0 +1,11 @@
# @package _global_
# Local launcher configuration (no SLURM)
defaults:
- override /hydra/launcher: basic
wandb:
enabled: True
config:
project: le-wm
entity: le-wm

64
config/train/lewm.yaml Normal file
View File

@@ -0,0 +1,64 @@
defaults:
- _self_
- data: pusht
output_model_name: lewm
subdir: ${hydra:job.id}
num_workers: 6
train_split: 0.9
seed: 3072
img_size: 224
patch_size: 14
encoder_scale: tiny
dump_object: True
trainer:
max_epochs: 100
devices: auto
accelerator: gpu
precision: bf16
gradient_clip_val: 1.0
loader:
batch_size: 128
num_workers: ${num_workers}
persistent_workers: True
prefetch_factor: 3
pin_memory: True
optimizer:
type: AdamW
lr: 5e-5
weight_decay: 1e-3
wandb:
enabled: True
config:
entity: lewm
project: lewm
name: ${output_model_name}
id: ${subdir}
resume: allow
log_model: False
wm:
type: lewm
history_size: 3
num_preds: 1
embed_dim: 192
predictor:
depth: 6
heads: 16
mlp_dim: 2048
dim_head: 64
dropout: 0.1
emb_dropout: 0.0
loss:
sigreg:
weight: 0.09
kwargs:
knots: 17
num_proj: 1024