Initial commit
This commit is contained in:
171
eval.py
Normal file
171
eval.py
Normal file
@@ -0,0 +1,171 @@
|
||||
import os
|
||||
|
||||
os.environ["MUJOCO_GL"] = "egl"
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
import stable_pretraining as spt
|
||||
import torch
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from sklearn import preprocessing
|
||||
from torchvision.transforms import v2 as transforms
|
||||
import stable_worldmodel as swm
|
||||
|
||||
def img_transform(cfg):
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.ToImage(),
|
||||
transforms.ToDtype(torch.float32, scale=True),
|
||||
transforms.Normalize(**spt.data.dataset_stats.ImageNet),
|
||||
transforms.Resize(size=cfg.eval.img_size),
|
||||
]
|
||||
)
|
||||
return transform
|
||||
|
||||
|
||||
def get_episodes_length(dataset, episodes):
|
||||
col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx"
|
||||
|
||||
episode_idx = dataset.get_col_data(col_name)
|
||||
step_idx = dataset.get_col_data("step_idx")
|
||||
lengths = []
|
||||
for ep_id in episodes:
|
||||
lengths.append(np.max(step_idx[episode_idx == ep_id]) + 1)
|
||||
return np.array(lengths)
|
||||
|
||||
|
||||
def get_dataset(cfg, dataset_name):
|
||||
dataset_path = Path(cfg.cache_dir or swm.data.utils.get_cache_dir())
|
||||
dataset = swm.data.HDF5Dataset(
|
||||
dataset_name,
|
||||
keys_to_cache=cfg.dataset.keys_to_cache,
|
||||
cache_dir=dataset_path,
|
||||
)
|
||||
return dataset
|
||||
|
||||
@hydra.main(version_base=None, config_path="./config/eval", config_name="pusht")
|
||||
def run(cfg: DictConfig):
|
||||
"""Run evaluation of dinowm vs random policy."""
|
||||
assert (
|
||||
cfg.plan_config.horizon * cfg.plan_config.action_block <= cfg.eval.eval_budget
|
||||
), "Planning horizon must be smaller than or equal to eval_budget"
|
||||
|
||||
# create world environment
|
||||
cfg.world.max_episode_steps = 2 * cfg.eval.eval_budget
|
||||
world = swm.World(**cfg.world, image_shape=(224, 224))
|
||||
|
||||
# create the transform
|
||||
transform = {
|
||||
"pixels": img_transform(cfg),
|
||||
"goal": img_transform(cfg),
|
||||
}
|
||||
|
||||
dataset = get_dataset(cfg, cfg.eval.dataset_name)
|
||||
stats_dataset = dataset # get_dataset(cfg, cfg.dataset.stats)
|
||||
col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx"
|
||||
ep_indices, _ = np.unique(stats_dataset.get_col_data(col_name), return_index=True)
|
||||
|
||||
process = {}
|
||||
for col in cfg.dataset.keys_to_cache:
|
||||
if col in ["pixels"]:
|
||||
continue
|
||||
processor = preprocessing.StandardScaler()
|
||||
col_data = stats_dataset.get_col_data(col)
|
||||
col_data = col_data[~np.isnan(col_data).any(axis=1)]
|
||||
processor.fit(col_data)
|
||||
process[col] = processor
|
||||
|
||||
if col != "action":
|
||||
process[f"goal_{col}"] = process[col]
|
||||
|
||||
# -- run evaluation
|
||||
policy = cfg.get("policy", "random")
|
||||
|
||||
if policy != "random":
|
||||
model = swm.policy.AutoCostModel(cfg.policy)
|
||||
model = model.to("cuda")
|
||||
model = model.eval()
|
||||
model.requires_grad_(False)
|
||||
model.interpolate_pos_encoding = True
|
||||
config = swm.PlanConfig(**cfg.plan_config)
|
||||
solver = hydra.utils.instantiate(cfg.solver, model=model)
|
||||
policy = swm.policy.WorldModelPolicy(
|
||||
solver=solver, config=config, process=process, transform=transform
|
||||
)
|
||||
|
||||
else:
|
||||
policy = swm.policy.RandomPolicy()
|
||||
|
||||
results_path = (
|
||||
Path(swm.data.utils.get_cache_dir(), cfg.policy).parent
|
||||
if cfg.policy != "random"
|
||||
else Path(__file__).parent
|
||||
)
|
||||
|
||||
# sample the episodes and the starting indices
|
||||
episode_len = get_episodes_length(dataset, ep_indices)
|
||||
max_start_idx = episode_len - cfg.eval.goal_offset_steps - 1
|
||||
max_start_idx_dict = {ep_id: max_start_idx[i] for i, ep_id in enumerate(ep_indices)}
|
||||
# Map each dataset row’s episode_idx to its max_start_idx
|
||||
col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx"
|
||||
max_start_per_row = np.array(
|
||||
[max_start_idx_dict[ep_id] for ep_id in dataset.get_col_data(col_name)]
|
||||
)
|
||||
|
||||
# remove all the lines of dataset for which dataset['step_idx'] > max_start_per_row
|
||||
valid_mask = dataset.get_col_data("step_idx") <= max_start_per_row
|
||||
valid_indices = np.nonzero(valid_mask)[0]
|
||||
print(valid_mask.sum(), "valid starting points found for evaluation.")
|
||||
|
||||
g = np.random.default_rng(cfg.seed)
|
||||
random_episode_indices = g.choice(
|
||||
len(valid_indices) - 1, size=cfg.eval.num_eval, replace=False
|
||||
)
|
||||
|
||||
# sort increasingly to avoid issues with HDF5Dataset indexing
|
||||
random_episode_indices = np.sort(valid_indices[random_episode_indices])
|
||||
|
||||
print(random_episode_indices)
|
||||
|
||||
eval_episodes = dataset.get_row_data(random_episode_indices)[col_name]
|
||||
eval_start_idx = dataset.get_row_data(random_episode_indices)["step_idx"]
|
||||
|
||||
if len(eval_episodes) < cfg.eval.num_eval:
|
||||
raise ValueError("Not enough episodes with sufficient length for evaluation.")
|
||||
|
||||
world.set_policy(policy)
|
||||
|
||||
start_time = time.time()
|
||||
metrics = world.evaluate_from_dataset(
|
||||
dataset,
|
||||
start_steps=eval_start_idx.tolist(),
|
||||
goal_offset_steps=cfg.eval.goal_offset_steps,
|
||||
eval_budget=cfg.eval.eval_budget,
|
||||
episodes_idx=eval_episodes.tolist(),
|
||||
callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True),
|
||||
video_path=results_path,
|
||||
)
|
||||
end_time = time.time()
|
||||
|
||||
print(metrics)
|
||||
|
||||
results_path = results_path / cfg.output.filename
|
||||
results_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with results_path.open("a") as f:
|
||||
f.write("\n") # separate from previous runs
|
||||
|
||||
f.write("==== CONFIG ====\n")
|
||||
f.write(OmegaConf.to_yaml(cfg))
|
||||
f.write("\n")
|
||||
|
||||
f.write("==== RESULTS ====\n")
|
||||
f.write(f"metrics: {metrics}\n")
|
||||
f.write(f"evaluation_time: {end_time - start_time} seconds\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
Reference in New Issue
Block a user