Initial commit

This commit is contained in:
Lucas Maes
2026-03-12 22:56:21 -04:00
committed by lucas-maes
commit 83f97d72ad
21 changed files with 1355 additions and 0 deletions

171
eval.py Normal file
View File

@@ -0,0 +1,171 @@
import os
os.environ["MUJOCO_GL"] = "egl"
import time
from pathlib import Path
import hydra
import numpy as np
import stable_pretraining as spt
import torch
from omegaconf import DictConfig, OmegaConf
from sklearn import preprocessing
from torchvision.transforms import v2 as transforms
import stable_worldmodel as swm
def img_transform(cfg):
transform = transforms.Compose(
[
transforms.ToImage(),
transforms.ToDtype(torch.float32, scale=True),
transforms.Normalize(**spt.data.dataset_stats.ImageNet),
transforms.Resize(size=cfg.eval.img_size),
]
)
return transform
def get_episodes_length(dataset, episodes):
col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx"
episode_idx = dataset.get_col_data(col_name)
step_idx = dataset.get_col_data("step_idx")
lengths = []
for ep_id in episodes:
lengths.append(np.max(step_idx[episode_idx == ep_id]) + 1)
return np.array(lengths)
def get_dataset(cfg, dataset_name):
dataset_path = Path(cfg.cache_dir or swm.data.utils.get_cache_dir())
dataset = swm.data.HDF5Dataset(
dataset_name,
keys_to_cache=cfg.dataset.keys_to_cache,
cache_dir=dataset_path,
)
return dataset
@hydra.main(version_base=None, config_path="./config/eval", config_name="pusht")
def run(cfg: DictConfig):
"""Run evaluation of dinowm vs random policy."""
assert (
cfg.plan_config.horizon * cfg.plan_config.action_block <= cfg.eval.eval_budget
), "Planning horizon must be smaller than or equal to eval_budget"
# create world environment
cfg.world.max_episode_steps = 2 * cfg.eval.eval_budget
world = swm.World(**cfg.world, image_shape=(224, 224))
# create the transform
transform = {
"pixels": img_transform(cfg),
"goal": img_transform(cfg),
}
dataset = get_dataset(cfg, cfg.eval.dataset_name)
stats_dataset = dataset # get_dataset(cfg, cfg.dataset.stats)
col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx"
ep_indices, _ = np.unique(stats_dataset.get_col_data(col_name), return_index=True)
process = {}
for col in cfg.dataset.keys_to_cache:
if col in ["pixels"]:
continue
processor = preprocessing.StandardScaler()
col_data = stats_dataset.get_col_data(col)
col_data = col_data[~np.isnan(col_data).any(axis=1)]
processor.fit(col_data)
process[col] = processor
if col != "action":
process[f"goal_{col}"] = process[col]
# -- run evaluation
policy = cfg.get("policy", "random")
if policy != "random":
model = swm.policy.AutoCostModel(cfg.policy)
model = model.to("cuda")
model = model.eval()
model.requires_grad_(False)
model.interpolate_pos_encoding = True
config = swm.PlanConfig(**cfg.plan_config)
solver = hydra.utils.instantiate(cfg.solver, model=model)
policy = swm.policy.WorldModelPolicy(
solver=solver, config=config, process=process, transform=transform
)
else:
policy = swm.policy.RandomPolicy()
results_path = (
Path(swm.data.utils.get_cache_dir(), cfg.policy).parent
if cfg.policy != "random"
else Path(__file__).parent
)
# sample the episodes and the starting indices
episode_len = get_episodes_length(dataset, ep_indices)
max_start_idx = episode_len - cfg.eval.goal_offset_steps - 1
max_start_idx_dict = {ep_id: max_start_idx[i] for i, ep_id in enumerate(ep_indices)}
# Map each dataset rows episode_idx to its max_start_idx
col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx"
max_start_per_row = np.array(
[max_start_idx_dict[ep_id] for ep_id in dataset.get_col_data(col_name)]
)
# remove all the lines of dataset for which dataset['step_idx'] > max_start_per_row
valid_mask = dataset.get_col_data("step_idx") <= max_start_per_row
valid_indices = np.nonzero(valid_mask)[0]
print(valid_mask.sum(), "valid starting points found for evaluation.")
g = np.random.default_rng(cfg.seed)
random_episode_indices = g.choice(
len(valid_indices) - 1, size=cfg.eval.num_eval, replace=False
)
# sort increasingly to avoid issues with HDF5Dataset indexing
random_episode_indices = np.sort(valid_indices[random_episode_indices])
print(random_episode_indices)
eval_episodes = dataset.get_row_data(random_episode_indices)[col_name]
eval_start_idx = dataset.get_row_data(random_episode_indices)["step_idx"]
if len(eval_episodes) < cfg.eval.num_eval:
raise ValueError("Not enough episodes with sufficient length for evaluation.")
world.set_policy(policy)
start_time = time.time()
metrics = world.evaluate_from_dataset(
dataset,
start_steps=eval_start_idx.tolist(),
goal_offset_steps=cfg.eval.goal_offset_steps,
eval_budget=cfg.eval.eval_budget,
episodes_idx=eval_episodes.tolist(),
callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True),
video_path=results_path,
)
end_time = time.time()
print(metrics)
results_path = results_path / cfg.output.filename
results_path.parent.mkdir(parents=True, exist_ok=True)
with results_path.open("a") as f:
f.write("\n") # separate from previous runs
f.write("==== CONFIG ====\n")
f.write(OmegaConf.to_yaml(cfg))
f.write("\n")
f.write("==== RESULTS ====\n")
f.write(f"metrics: {metrics}\n")
f.write(f"evaluation_time: {end_time - start_time} seconds\n")
if __name__ == "__main__":
run()