Files
lewm/eval.py
qihuanye 8ba5bc8b0b 多卡
2026-04-10 03:13:54 +00:00

558 lines
18 KiB
Python

import os
os.environ["MUJOCO_GL"] = "egl"
import multiprocessing as mp
import time
import traceback
from contextlib import nullcontext
from pathlib import Path
import hydra
import numpy as np
import stable_pretraining as spt
import torch
from omegaconf import DictConfig, OmegaConf
from sklearn import preprocessing
from torchvision.transforms import v2 as transforms
import stable_worldmodel as swm
def img_transform(cfg):
    """Build the image preprocessing pipeline used for observations and goals.

    Steps, in order: convert to an image tensor, cast to float32 (integer
    inputs are rescaled by ``ToDtype(scale=True)``), normalize with the
    mean/std in ``spt.data.dataset_stats.ImageNet``, then resize to
    ``cfg.eval.img_size``.
    """
    pipeline = [
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
        transforms.Normalize(**spt.data.dataset_stats.ImageNet),
        transforms.Resize(size=cfg.eval.img_size),
    ]
    return transforms.Compose(pipeline)
def get_episodes_length(dataset, episodes):
    """Return ``max(step_idx) + 1`` for each episode id in ``episodes``.

    The episode column is ``episode_idx`` when the dataset has it, otherwise
    ``ep_idx``. The value equals the episode length assuming step indices start
    at 0 and are contiguous (presumably true for these datasets — confirm).
    An id with no rows makes ``np.max`` raise ``ValueError``.

    Returns:
        np.ndarray with one entry per id, in the order of ``episodes``.
    """
    if "episode_idx" in dataset.column_names:
        key = "episode_idx"
    else:
        key = "ep_idx"
    row_episode = dataset.get_col_data(key)
    row_step = dataset.get_col_data("step_idx")
    return np.array([np.max(row_step[row_episode == ep]) + 1 for ep in episodes])
def get_dataset(cfg, dataset_name):
    """Open ``dataset_name`` as a ``swm.data.HDF5Dataset``.

    Columns listed in ``cfg.dataset.keys_to_cache`` are passed as
    ``keys_to_cache``. The cache directory is ``cfg.cache_dir``, or
    ``swm.data.utils.get_cache_dir()`` when that is falsy.
    """
    cache_root = cfg.cache_dir or swm.data.utils.get_cache_dir()
    return swm.data.HDF5Dataset(
        dataset_name,
        keys_to_cache=cfg.dataset.keys_to_cache,
        cache_dir=Path(cache_root),
    )
def get_profile_cfg(cfg):
    """Return torch-profiler settings for this run.

    Starts from built-in defaults (profiling disabled) and overlays the
    resolved ``cfg.profile`` section when one is present. A fresh dict is
    built on every call, so callers may mutate the result.
    """
    settings = dict(
        enabled=False,
        trace_dirname="torch_profile",
        record_shapes=True,
        profile_memory=True,
        with_stack=False,
        with_flops=False,
        row_limit=40,
        worker_name="eval",
        export_chrome_trace=True,
        export_tensorboard=True,
    )
    overrides = cfg.get("profile")
    if overrides is None:
        return settings
    settings.update(OmegaConf.to_container(overrides, resolve=True))
    return settings
def get_compile_cfg(cfg):
    """Return ``torch.compile`` settings for inference.

    Defaults compile ``model.predictor`` in ``reduce-overhead`` mode on CUDA
    only. A ``cfg.compile`` section, when present, is resolved and overlaid
    on top. Each call returns a new dict.
    """
    settings = dict(
        enabled=True,
        target="predictor",
        mode="reduce-overhead",
        fullgraph=False,
        dynamic=False,
        cuda_only=True,
    )
    overrides = cfg.get("compile")
    if overrides is None:
        return settings
    settings.update(OmegaConf.to_container(overrides, resolve=True))
    return settings
def maybe_compile_inference_target(model, cfg, device):
    """Optionally replace ``model.predictor`` or ``model.predict`` with a compiled version.

    The attribute named by ``compile.target`` (case-insensitive; ``predictor``
    or ``predict``) is wrapped with ``torch.compile`` using the configured
    ``mode``/``fullgraph``/``dynamic`` and assigned back onto ``model``.
    Compilation is skipped, with a printed reason where applicable, when:
    - it is disabled;
    - ``torch.compile`` does not exist;
    - ``cuda_only`` is set and ``device`` is not CUDA;
    - the target is unsupported or missing on the model.

    Returns:
        (model, compile_cfg, compile_target). ``model`` is the same object,
        possibly modified in place. ``compile_target`` is the compiled
        attribute name, or ``"disabled"`` when nothing was compiled.
    """
    compile_cfg = get_compile_cfg(cfg)
    untouched = (model, compile_cfg, "disabled")

    if not compile_cfg["enabled"]:
        return untouched
    if not hasattr(torch, "compile"):
        print("torch.compile is unavailable, skipping inference compilation.")
        return untouched
    if compile_cfg["cuda_only"] and not str(device).startswith("cuda"):
        print("Skipping torch.compile because compile.cuda_only=true and device is not CUDA.")
        return untouched

    target = str(compile_cfg["target"]).lower()
    if target not in ("predictor", "predict"):
        print(
            f"Unsupported compile.target={target}. Expected one of: predictor, predict."
        )
        return untouched
    if not hasattr(model, target):
        print(f"Requested compile target '{target}' is unavailable on the model.")
        return untouched

    compiled = torch.compile(
        getattr(model, target),
        mode=compile_cfg["mode"],
        fullgraph=compile_cfg["fullgraph"],
        dynamic=compile_cfg["dynamic"],
    )
    setattr(model, target, compiled)
    return model, compile_cfg, target
def get_inference_context(cfg, device):
    """Return an autocast context and the effective precision label for inference.

    Args:
        cfg: config mapping. ``inference_precision`` (default ``"fp32"``) is
            read case-insensitively. Accepted values: ``fp32``,
            ``bf16``/``bfloat16`` and ``fp16``/``float16``.
        device: device string (e.g. ``"cuda:0"``, ``"cpu"``) or a
            ``torch.device``. Anything whose string form starts with
            ``"cuda"`` is treated as CUDA; everything else as CPU.

    Returns:
        (context_manager, precision_label). The label is the precision
        actually used. fp16 on a non-CUDA device falls back to fp32 with a
        printed warning.

    Raises:
        ValueError: for an unrecognized ``inference_precision``.
    """
    precision = str(cfg.get("inference_precision", "fp32")).lower()
    # str() so torch.device objects work too, as in maybe_compile_inference_target.
    device_type = "cuda" if str(device).startswith("cuda") else "cpu"
    if precision == "fp32":
        return nullcontext(), "fp32"
    if precision in {"bf16", "bfloat16"}:
        return (
            torch.autocast(device_type=device_type, dtype=torch.bfloat16),
            "bf16",
        )
    if precision in {"fp16", "float16"}:
        if device_type != "cuda":
            print("fp16 inference is only supported on CUDA, falling back to fp32.")
            return nullcontext(), "fp32"
        return (
            torch.autocast(device_type=device_type, dtype=torch.float16),
            "fp16",
        )
    raise ValueError(
        f"Unsupported inference_precision={precision}. Expected one of: fp32, bf16, fp16."
    )
def make_profiler(cfg, results_path):
    """Create the torch profiler context for an evaluation run, if enabled.

    Returns:
        (context, profile_dir, profile_cfg). When ``profile.enabled`` is
        false, ``context`` is a ``nullcontext`` and ``profile_dir`` is None.
        Otherwise ``context`` is a ``torch.profiler.profile`` recording CPU
        activity, plus CUDA when available. ``profile_dir`` is then
        ``results_path / profile.trace_dirname``, created if missing.
    """
    profile_cfg = get_profile_cfg(cfg)
    if not profile_cfg["enabled"]:
        return nullcontext(), None, profile_cfg

    cpu = torch.profiler.ProfilerActivity.CPU
    activities = (
        [cpu, torch.profiler.ProfilerActivity.CUDA]
        if torch.cuda.is_available()
        else [cpu]
    )
    profile_dir = results_path / profile_cfg["trace_dirname"]
    profile_dir.mkdir(parents=True, exist_ok=True)
    option_names = ("record_shapes", "profile_memory", "with_stack", "with_flops")
    profiler = torch.profiler.profile(
        activities=activities,
        **{name: profile_cfg[name] for name in option_names},
    )
    return profiler, profile_dir, profile_cfg
def dump_profiler_results(profiler, profile_dir, profile_cfg):
    """Write profiler output into ``profile_dir``.

    Saves the ``key_averages()`` table (top ``row_limit`` rows, sorted by self
    CUDA time when CUDA is available, else self CPU time) to
    ``key_averages.txt``. It then exports a TensorBoard trace when
    ``export_tensorboard`` is set; otherwise, when ``export_chrome_trace`` is
    set, it writes a Chrome trace to ``trace.json``.

    Returns:
        Path of the summary file, or None when there is no profiler or directory.
    """
    if profiler is None or profile_dir is None:
        return None
    sort_key = (
        "self_cuda_time_total" if torch.cuda.is_available() else "self_cpu_time_total"
    )
    summary = profiler.key_averages().table(
        sort_by=sort_key,
        row_limit=profile_cfg["row_limit"],
    )
    summary_path = profile_dir / "key_averages.txt"
    summary_path.write_text(summary)
    if profile_cfg["export_tensorboard"]:
        torch.profiler.tensorboard_trace_handler(
            str(profile_dir), worker_name=profile_cfg["worker_name"]
        )(profiler)
    elif profile_cfg["export_chrome_trace"]:
        profiler.export_chrome_trace(str(profile_dir / "trace.json"))
    return summary_path
def get_multi_gpu_cfg(cfg):
    """Return multi-GPU evaluation settings.

    Defaults: disabled, ``devices=None`` (all visible CUDA devices are chosen
    later) and the ``spawn`` start method. A ``cfg.multi_gpu`` section, when
    present, is resolved and overlaid. Each call returns a new dict.
    """
    settings = dict(enabled=False, devices=None, start_method="spawn")
    overrides = cfg.get("multi_gpu")
    if overrides is None:
        return settings
    settings.update(OmegaConf.to_container(overrides, resolve=True))
    return settings
def build_process(cfg, dataset):
    """Fit one ``StandardScaler`` per cached non-pixel column.

    For every column in ``cfg.dataset.keys_to_cache`` except ``"pixels"``,
    rows containing any NaN are dropped (the ``any(axis=1)`` requires 2-D
    column data), and a scaler is fitted on the rest. Every column except
    ``"action"`` also gets a ``goal_<col>`` entry that shares the same
    scaler object.

    Returns:
        dict mapping column name (and ``goal_<col>``) to its fitted scaler.
    """
    scalers = {}
    for column in cfg.dataset.keys_to_cache:
        if column == "pixels":
            continue
        values = dataset.get_col_data(column)
        finite_rows = ~np.isnan(values).any(axis=1)
        scaler = preprocessing.StandardScaler().fit(values[finite_rows])
        scalers[column] = scaler
        if column != "action":
            scalers[f"goal_{column}"] = scaler
    return scalers
def sample_eval_cases(cfg, dataset):
    """Sample ``cfg.eval.num_eval`` (episode, start step) evaluation cases.

    A row is a valid start when ``step_idx <= length - goal_offset_steps - 1``
    for its episode, so the goal ``goal_offset_steps`` later is still inside
    the episode. Rows are drawn without replacement using
    ``np.random.default_rng(cfg.seed)`` and sorted by row index before being
    read back.

    Returns:
        (eval_episodes, eval_start_idx): the episode-id and ``step_idx``
        fields of the sampled rows, as returned by ``dataset.get_row_data``.

    Raises:
        ValueError: if there are too few valid start rows to draw
            ``cfg.eval.num_eval`` cases. This is now checked before sampling,
            so the message is meaningful instead of numpy's.
    """
    col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx"
    row_episodes = dataset.get_col_data(col_name)
    ep_indices = np.unique(row_episodes)  # sorted unique episode ids
    episode_len = get_episodes_length(dataset, ep_indices)
    max_start_idx = episode_len - cfg.eval.goal_offset_steps - 1
    # ep_indices is sorted, so searchsorted maps each row to its episode's slot
    # (vectorized replacement for a per-row dict lookup).
    max_start_per_row = max_start_idx[np.searchsorted(ep_indices, row_episodes)]
    valid_mask = dataset.get_col_data("step_idx") <= max_start_per_row
    valid_indices = np.nonzero(valid_mask)[0]
    print(valid_mask.sum(), "valid starting points found for evaluation.")

    num_eval = cfg.eval.num_eval
    # NOTE(review): the population is len(valid_indices) - 1, so the last valid
    # row can never be drawn. Kept deliberately: changing it would change which
    # cases a given seed selects and break comparability with earlier runs.
    population = len(valid_indices) - 1
    if population < num_eval:
        raise ValueError("Not enough episodes with sufficient length for evaluation.")
    rng = np.random.default_rng(cfg.seed)
    picks = rng.choice(population, size=num_eval, replace=False)
    random_episode_indices = np.sort(valid_indices[picks])
    print(random_episode_indices)
    rows = dataset.get_row_data(random_episode_indices)
    return rows[col_name], rows["step_idx"]
def normalize_multi_gpu_devices(devices):
    """Turn a ``multi_gpu.devices`` setting into a list of torch device strings.

    Accepted forms:
      * ``None``: every visible CUDA device (``cuda:0`` .. ``cuda:N-1``); empty
        when ``torch.cuda.device_count()`` is 0.
      * an iterable of entries. Ints and digit strings become ``cuda:<i>``;
        any other entry (e.g. ``"cuda:1"``) is passed through as ``str``.
        Surrounding whitespace on string entries is ignored.
      * a single int, or a comma-separated string such as ``"0,1"``. These are
        split into entries first; previously an int raised TypeError and a
        string was iterated character by character.
    """
    if devices is None:
        return [f"cuda:{idx}" for idx in range(torch.cuda.device_count())]
    if isinstance(devices, int):
        devices = [devices]
    elif isinstance(devices, str):
        # Iterating a str directly would yield "0", ",", "1" for "0,1".
        devices = [part for part in devices.split(",") if part.strip()]
    normalized = []
    for device in devices:
        if isinstance(device, str):
            device = device.strip()
        if isinstance(device, int) or (isinstance(device, str) and device.isdigit()):
            normalized.append(f"cuda:{int(device)}")
        else:
            normalized.append(str(device))
    return normalized
def shard_eval_cases(eval_episodes, eval_start_idx, num_shards):
    """Split aligned case sequences into at most ``num_shards`` contiguous shards.

    Shard sizes differ by at most one; the first ``len % num_shards`` shards
    get the extra element. Empty shards are omitted, so fewer than
    ``num_shards`` pairs are returned when there are fewer cases than shards.

    Returns:
        list of ``(episodes_slice, start_idx_slice)`` tuples in input order.

    Raises:
        ValueError: if ``num_shards < 1``.
    """
    if num_shards < 1:
        raise ValueError("num_shards must be >= 1")
    base, extra = divmod(len(eval_episodes), num_shards)
    result = []
    offset = 0
    for shard in range(num_shards):
        width = base + (1 if shard < extra else 0)
        if width == 0:
            continue
        stop = offset + width
        result.append((eval_episodes[offset:stop], eval_start_idx[offset:stop]))
        offset = stop
    return result
def run_eval_subset(
    cfg: DictConfig,
    eval_episodes,
    eval_start_idx,
    output_dir: Path,
    *,
    device_override: str | None = None,
    enable_profile: bool = True,
):
    """Evaluate one batch of (episode, start step) cases on a single device.

    Works on a private copy of ``cfg`` sized to this batch:
    ``eval.num_eval`` and ``world.num_envs`` are set to ``len(eval_episodes)``,
    and ``world.max_episode_steps`` to twice ``eval.eval_budget``.

    Args:
        cfg: full evaluation config; not modified.
        eval_episodes: episode ids, one per evaluation case.
        eval_start_idx: start step for each case, aligned with ``eval_episodes``.
        output_dir: passed to the world as ``video_path`` and used as the root
            for profiler output.
        device_override: device string such as ``"cuda:1"``. When given, it
            is written to ``solver.device``, made the current CUDA device (if
            CUDA), and used for the model.
        enable_profile: when False, profiling is forced off regardless of ``cfg``.

    Returns:
        dict with:
        - ``metrics``: from ``world.evaluate_from_dataset``.
        - ``evaluation_time``: wall-clock seconds around the evaluation call.
        - ``inference_precision`` and ``compile_target``.
        - ``compile_mode``: None when nothing was compiled.
        - ``profile_dir`` and ``profile_summary_path``: None when profiling is off.
    """
    # Rebuilding from a plain container yields an independent copy. Flags such
    # as struct mode are not carried over, so new keys (e.g. ``profile``) can be
    # added below.
    local_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
    local_cfg.eval.num_eval = len(eval_episodes)
    local_cfg.world.num_envs = len(eval_episodes)
    local_cfg.world.max_episode_steps = 2 * local_cfg.eval.eval_budget
    if device_override is not None:
        local_cfg.solver.device = device_override
        if torch.cuda.is_available() and str(device_override).startswith("cuda"):
            torch.cuda.set_device(torch.device(device_override))
    if not enable_profile:
        if local_cfg.get("profile") is None:
            local_cfg.profile = OmegaConf.create({"enabled": False})
        else:
            local_cfg.profile.enabled = False
    device = device_override or ("cuda" if torch.cuda.is_available() else "cpu")

    world = swm.World(**local_cfg.world, image_shape=(224, 224))
    transform = {
        "pixels": img_transform(local_cfg),
        "goal": img_transform(local_cfg),
    }
    dataset = get_dataset(local_cfg, local_cfg.eval.dataset_name)
    process = build_process(local_cfg, dataset)

    if local_cfg.get("policy", "random") != "random":
        model = swm.policy.AutoCostModel(local_cfg.policy)
        model = model.to(device).eval()
        model.requires_grad_(False)
        model, compile_cfg, compile_target = maybe_compile_inference_target(
            model, local_cfg, device
        )
        inference_ctx, inference_precision = get_inference_context(local_cfg, device)
        model.interpolate_pos_encoding = True
        config = swm.PlanConfig(**local_cfg.plan_config)
        solver = hydra.utils.instantiate(local_cfg.solver, model=model)
        policy = swm.policy.WorldModelPolicy(
            solver=solver, config=config, process=process, transform=transform
        )
    else:
        policy = swm.policy.RandomPolicy()
        inference_ctx = nullcontext()
        inference_precision = "fp32"
        compile_cfg = get_compile_cfg(local_cfg)
        compile_target = "disabled"

    profiler_ctx, profile_dir, profile_cfg = make_profiler(local_cfg, output_dir)
    world.set_policy(policy)

    # OmegaConf.to_container rejects None, so only convert when callables are
    # configured. Assumes evaluate_from_dataset treats callables=None as
    # "none" — TODO confirm.
    callables_cfg = local_cfg.eval.get("callables")
    callables = (
        None
        if callables_cfg is None
        else OmegaConf.to_container(callables_cfg, resolve=True)
    )

    on_cuda = str(device).startswith("cuda") and torch.cuda.is_available()
    # Synchronize around the timed region so queued GPU work is attributed correctly.
    if on_cuda:
        torch.cuda.synchronize()
    start_time = time.time()
    with (
        torch.inference_mode(),
        profiler_ctx as profiler,
        inference_ctx,
        torch.profiler.record_function("eval.world_evaluate_from_dataset"),
    ):
        metrics = world.evaluate_from_dataset(
            dataset,
            start_steps=list(eval_start_idx),
            goal_offset_steps=local_cfg.eval.goal_offset_steps,
            eval_budget=local_cfg.eval.eval_budget,
            episodes_idx=list(eval_episodes),
            callables=callables,
            save_video=False,
            video_path=output_dir,
        )
    if on_cuda:
        torch.cuda.synchronize()
    evaluation_time = time.time() - start_time

    profile_summary_path = dump_profiler_results(profiler, profile_dir, profile_cfg)
    return {
        "metrics": metrics,
        "evaluation_time": evaluation_time,
        "inference_precision": inference_precision,
        "compile_target": compile_target,
        "compile_mode": compile_cfg["mode"] if compile_target != "disabled" else None,
        "profile_dir": profile_dir,
        "profile_summary_path": profile_summary_path,
    }
def _multi_gpu_eval_worker(
    cfg_container,
    eval_episodes,
    eval_start_idx,
    output_dir,
    device,
    shard_idx,
    queue,
):
    """Child-process entry point: evaluate one shard and post the outcome on ``queue``.

    On success it posts ``{"ok": True, "shard_idx", "result"}``. On any
    ``Exception`` it posts ``{"ok": False, "shard_idx", "error"}``, where
    ``error`` is the formatted traceback, so the parent can re-raise it.
    Profiling is always disabled in workers.
    """
    try:
        shard_cfg = OmegaConf.create(cfg_container)
        shard_result = run_eval_subset(
            shard_cfg,
            eval_episodes,
            eval_start_idx,
            Path(output_dir),
            device_override=device,
            enable_profile=False,
        )
        queue.put({"ok": True, "shard_idx": shard_idx, "result": shard_result})
    except Exception:  # process boundary: report every failure to the parent
        failure = {
            "ok": False,
            "shard_idx": shard_idx,
            "error": traceback.format_exc(),
        }
        queue.put(failure)
def _collect_shard_results(processes, queue, poll_interval):
    """Receive one report per worker, raising if a worker exits without reporting.

    Args:
        processes: started worker processes, indexed by shard index.
        queue: queue the workers post ``{"ok", "shard_idx", ...}`` messages to.
        poll_interval: seconds to wait on the queue before checking whether
            any still-pending worker has already exited.

    Returns:
        (shard_results, errors): successful results keyed by shard index, and
        the traceback strings reported by failed workers.

    Raises:
        RuntimeError: if a pending worker has exited (e.g. killed by a signal
            or a native crash) and no message arrives for it. Remaining
            workers are terminated first.
    """
    # Imports the stdlib module; the ``queue`` parameter does not shadow it here.
    from queue import Empty

    shard_results = {}
    errors = []
    pending = set(range(len(processes)))
    while pending:
        try:
            message = queue.get(timeout=poll_interval)
        except Empty:
            dead = sorted(i for i in pending if processes[i].exitcode is not None)
            if not dead:
                continue  # all pending workers still running
            # A worker that exits normally flushes its queued data first, so one
            # more wait catches a report that landed between the timeout and the
            # exit-code check.
            try:
                message = queue.get(timeout=poll_interval)
            except Empty:
                exit_codes = [processes[i].exitcode for i in dead]
                for proc in processes:
                    if proc.is_alive():
                        proc.terminate()
                for proc in processes:
                    proc.join(timeout=poll_interval)
                raise RuntimeError(
                    f"Evaluation worker(s) for shard(s) {dead} exited without "
                    f"reporting a result (exit codes: {exit_codes})."
                ) from None
        pending.discard(message["shard_idx"])
        if message["ok"]:
            shard_results[message["shard_idx"]] = message["result"]
        else:
            errors.append(message["error"])
    return shard_results, errors


def _merge_shard_metrics(ordered_results):
    """Concatenate per-shard success flags (and seeds, if every shard has them) into one metrics dict."""
    episode_successes = np.concatenate(
        [
            np.asarray(result["metrics"]["episode_successes"], dtype=np.bool_)
            for result in ordered_results
        ]
    )
    shard_seeds = [result["metrics"].get("seeds") for result in ordered_results]
    seeds = (
        np.concatenate(shard_seeds)
        if all(seed is not None for seed in shard_seeds)
        else None
    )
    # NOTE(review): only these three keys are merged; any other per-shard metric
    # keys are dropped in multi-GPU mode.
    return {
        "success_rate": float(np.sum(episode_successes)) / len(episode_successes) * 100.0,
        "episode_successes": episode_successes,
        "seeds": seeds,
    }


def run_multi_gpu_eval(
    cfg,
    eval_episodes,
    eval_start_idx,
    output_dir: Path,
    *,
    poll_interval: float = 5.0,
):
    """Evaluate cases sharded across several CUDA devices, one process per device.

    Cases are split into contiguous shards (at most one per device). Each
    shard runs ``run_eval_subset`` in a child process started with
    ``multi_gpu.start_method``. Results are merged in shard order, so the
    merged ``episode_successes`` follows the input case order.

    Args:
        cfg: full evaluation config.
        eval_episodes: episode ids, one per case.
        eval_start_idx: start steps aligned with ``eval_episodes``.
        output_dir: directory forwarded to each worker.
        poll_interval: seconds between liveness checks while waiting for
            worker reports.

    Returns:
        dict shaped like ``run_eval_subset``'s result. ``profile_dir`` and
        ``profile_summary_path`` are always None. ``evaluation_time`` is
        wall-clock time from spawning workers to collecting all reports, so
        unlike the single-device path it includes per-worker startup
        (dataset/model loading).

    Raises:
        ValueError: if fewer than 2 devices are configured/visible, or there
            are no cases to shard.
        RuntimeError: with the first worker traceback if any shard failed,
            or if a worker died without reporting.
    """
    multi_gpu_cfg = get_multi_gpu_cfg(cfg)
    devices = normalize_multi_gpu_devices(multi_gpu_cfg["devices"])
    if len(devices) < 2:
        raise ValueError("multi_gpu.enabled=true requires at least 2 CUDA devices")
    shards = shard_eval_cases(
        eval_episodes, eval_start_idx, min(len(devices), len(eval_episodes))
    )
    devices = devices[: len(shards)]
    ctx = mp.get_context(multi_gpu_cfg["start_method"])
    queue = ctx.Queue()
    # Unresolved on purpose (presumably so interpolations see the per-shard
    # overrides applied in run_eval_subset).
    cfg_container = OmegaConf.to_container(cfg, resolve=False)

    processes = []
    start_time = time.time()
    for shard_idx, ((shard_episodes, shard_start_idx), device) in enumerate(
        zip(shards, devices, strict=True)
    ):
        process = ctx.Process(
            target=_multi_gpu_eval_worker,
            args=(
                cfg_container,
                list(shard_episodes),
                list(shard_start_idx),
                str(output_dir),
                device,
                shard_idx,
                queue,
            ),
        )
        process.start()
        processes.append(process)

    # Drain the queue before joining: a child with data still buffered for the
    # queue cannot exit, so joining first could deadlock.
    shard_results, errors = _collect_shard_results(processes, queue, poll_interval)
    for process in processes:
        process.join()
    if errors:
        raise RuntimeError(errors[0])

    ordered_results = [shard_results[idx] for idx in range(len(processes))]
    metrics = _merge_shard_metrics(ordered_results)
    reference = ordered_results[0]
    return {
        "metrics": metrics,
        "evaluation_time": time.time() - start_time,
        "inference_precision": reference["inference_precision"],
        "compile_target": reference["compile_target"],
        "compile_mode": reference["compile_mode"],
        "profile_dir": None,
        "profile_summary_path": None,
    }
@hydra.main(version_base=None, config_path="./config/eval", config_name="pusht")
def run(cfg: DictConfig):
    """Run evaluation of dinowm vs random policy.

    The run proceeds as follows:
    - Sample evaluation cases from ``cfg.eval.dataset_name``.
    - Evaluate them on one device or, when ``multi_gpu.enabled`` is true,
      sharded across several CUDA devices.
    - Print the metrics.
    - Append the config plus results to ``<cwd>/<cfg.output.filename>``.

    Raises:
        ValueError: if ``plan_config.horizon * plan_config.action_block``
            exceeds ``eval.eval_budget``, or if profiling is requested
            together with multi-GPU evaluation.
    """
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    planned_steps = cfg.plan_config.horizon * cfg.plan_config.action_block
    if planned_steps > cfg.eval.eval_budget:
        raise ValueError("Planning horizon must be smaller than or equal to eval_budget")
    dataset = get_dataset(cfg, cfg.eval.dataset_name)
    eval_episodes, eval_start_idx = sample_eval_cases(cfg, dataset)
    output_dir = Path.cwd().resolve()
    profile_cfg = get_profile_cfg(cfg)
    if get_multi_gpu_cfg(cfg)["enabled"]:
        if profile_cfg["enabled"]:
            raise ValueError("Profiling is not supported together with multi_gpu.enabled=true")
        eval_result = run_multi_gpu_eval(cfg, eval_episodes, eval_start_idx, output_dir)
    else:
        eval_result = run_eval_subset(
            cfg,
            eval_episodes.tolist(),
            eval_start_idx.tolist(),
            output_dir,
        )
    metrics = eval_result["metrics"]
    evaluation_time = eval_result["evaluation_time"]
    inference_precision = eval_result["inference_precision"]
    compile_target = eval_result["compile_target"]
    compile_mode = eval_result["compile_mode"]
    profile_dir = eval_result["profile_dir"]
    profile_summary_path = eval_result["profile_summary_path"]
    print(metrics)
    results_path = output_dir / cfg.output.filename
    results_path.parent.mkdir(parents=True, exist_ok=True)
    # Append mode: successive runs accumulate in the same results file.
    with results_path.open("a") as f:
        f.write("\n")  # separate from previous runs
        f.write("==== CONFIG ====\n")
        f.write(OmegaConf.to_yaml(cfg))
        f.write("\n")
        f.write("==== RESULTS ====\n")
        f.write(f"metrics: {metrics}\n")
        f.write(f"evaluation_time: {evaluation_time} seconds\n")
        f.write(f"inference_precision: {inference_precision}\n")
        f.write(f"inference_compile_target: {compile_target}\n")
        if compile_target != "disabled":
            f.write(f"inference_compile_mode: {compile_mode}\n")
        if profile_cfg["enabled"]:
            f.write(f"profile_dir: {profile_dir}\n")
            if profile_summary_path is not None:
                f.write(f"profile_summary: {profile_summary_path}\n")
if __name__ == "__main__":
    # ``run`` is wrapped by ``hydra.main``, which builds ``cfg`` from the config
    # files and command-line overrides, so it is called with no arguments.
    run()