"""Evaluate a world-model planning policy (or a random baseline) on start/goal
pairs sampled from a dataset, optionally sharded across several GPUs."""

import os

# Set before the remaining imports, presumably so that any MuJoCo import they
# trigger picks up the headless EGL rendering backend.
os.environ["MUJOCO_GL"] = "egl"

import multiprocessing as mp
import queue as stdlib_queue
import time
import traceback
from contextlib import nullcontext
from pathlib import Path

import hydra
import numpy as np
import stable_pretraining as spt
import torch
from omegaconf import DictConfig, OmegaConf
from sklearn import preprocessing
from torchvision.transforms import v2 as transforms

import stable_worldmodel as swm


def img_transform(cfg):
    """Build the image preprocessing pipeline shared by observations and goals.

    Converts the input to a float32 tensor scaled to [0, 1], normalizes it with
    the ImageNet statistics from ``stable_pretraining`` and resizes it to
    ``cfg.eval.img_size``.
    """
    return transforms.Compose(
        [
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(**spt.data.dataset_stats.ImageNet),
            transforms.Resize(size=cfg.eval.img_size),
        ]
    )


def _episode_column(dataset):
    """Return the name of the episode-id column (``episode_idx`` or ``ep_idx``)."""
    return "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx"


def get_episodes_length(dataset, episodes):
    """Return a numpy array with the length (in steps) of each episode in ``episodes``.

    The length is the episode's largest ``step_idx`` plus one, i.e. step
    indices are assumed to start at 0 (TODO confirm for every dataset).
    """
    episode_idx = dataset.get_col_data(_episode_column(dataset))
    step_idx = dataset.get_col_data("step_idx")
    return np.array([np.max(step_idx[episode_idx == ep_id]) + 1 for ep_id in episodes])


def get_dataset(cfg, dataset_name):
    """Open ``dataset_name`` as a ``swm.data.HDF5Dataset``.

    Reads from ``cfg.cache_dir``, falling back to the default stable_worldmodel
    cache directory, and forwards ``cfg.dataset.keys_to_cache``.
    """
    dataset_path = Path(cfg.cache_dir or swm.data.utils.get_cache_dir())
    return swm.data.HDF5Dataset(
        dataset_name,
        keys_to_cache=cfg.dataset.keys_to_cache,
        cache_dir=dataset_path,
    )


def _merge_cfg_section(defaults, cfg, key):
    """Overlay the resolved ``cfg[key]`` section (when present) onto ``defaults``."""
    section = cfg.get(key)
    if section is not None:
        defaults.update(OmegaConf.to_container(section, resolve=True))
    return defaults


def get_profile_cfg(cfg):
    """Return the torch-profiler settings: built-in defaults updated by ``cfg.profile``."""
    defaults = {
        "enabled": False,
        "trace_dirname": "torch_profile",
        "record_shapes": True,
        "profile_memory": True,
        "with_stack": False,
        "with_flops": False,
        "row_limit": 40,
        "worker_name": "eval",
        "export_chrome_trace": True,
        "export_tensorboard": True,
    }
    return _merge_cfg_section(defaults, cfg, "profile")


def get_compile_cfg(cfg):
    """Return the ``torch.compile`` settings: built-in defaults updated by ``cfg.compile``."""
    defaults = {
        "enabled": True,
        "target": "predictor",
        "mode": "reduce-overhead",
        "fullgraph": False,
        "dynamic": False,
        "cuda_only": True,
    }
    return _merge_cfg_section(defaults, cfg, "compile")


def maybe_compile_inference_target(model, cfg, device):
    """Optionally replace ``model.predictor`` or ``model.predict`` with a compiled version.

    Compilation is skipped, printing the reason, when ``compile.enabled`` is
    false, ``torch.compile`` does not exist, ``compile.cuda_only`` is set and
    ``device`` is not CUDA, or the requested target is unknown or missing on
    the model.

    Returns:
        ``(model, compile_cfg, compile_target)``. ``compile_target`` is the
        name of the compiled attribute, or ``"disabled"``.
    """
    compile_cfg = get_compile_cfg(cfg)
    disabled = (model, compile_cfg, "disabled")

    if not compile_cfg["enabled"]:
        return disabled
    if not hasattr(torch, "compile"):
        print("torch.compile is unavailable, skipping inference compilation.")
        return disabled
    if compile_cfg["cuda_only"] and not str(device).startswith("cuda"):
        print("Skipping torch.compile because compile.cuda_only=true and device is not CUDA.")
        return disabled

    target = str(compile_cfg["target"]).lower()
    if target not in ("predictor", "predict"):
        print(f"Unsupported compile.target={target}. Expected one of: predictor, predict.")
        return disabled
    if not hasattr(model, target):
        print(f"Requested compile target '{target}' is unavailable on the model.")
        return disabled

    compiled = torch.compile(
        getattr(model, target),
        mode=compile_cfg["mode"],
        fullgraph=compile_cfg["fullgraph"],
        dynamic=compile_cfg["dynamic"],
    )
    setattr(model, target, compiled)
    return model, compile_cfg, target


def get_inference_context(cfg, device):
    """Return ``(context_manager, precision_name)`` for ``cfg.inference_precision``.

    ``fp32`` uses no autocast. ``bf16`` autocasts on the device type of
    ``device``. ``fp16`` autocasts on CUDA only; on other devices it falls back
    to fp32 with a printed notice.

    Raises:
        ValueError: for an unrecognised precision string.
    """
    precision = str(cfg.get("inference_precision", "fp32")).lower()
    # str() so both "cuda:0" strings and torch.device objects are accepted.
    device_type = "cuda" if str(device).startswith("cuda") else "cpu"
    if precision == "fp32":
        return nullcontext(), "fp32"
    if precision in {"bf16", "bfloat16"}:
        return torch.autocast(device_type=device_type, dtype=torch.bfloat16), "bf16"
    if precision in {"fp16", "float16"}:
        if device_type != "cuda":
            print("fp16 inference is only supported on CUDA, falling back to fp32.")
            return nullcontext(), "fp32"
        return torch.autocast(device_type=device_type, dtype=torch.float16), "fp16"
    raise ValueError(
        f"Unsupported inference_precision={precision}. Expected one of: fp32, bf16, fp16."
    )


def make_profiler(cfg, results_path):
    """Create the torch profiler requested by ``cfg.profile``.

    Returns:
        ``(context, profile_dir, profile_cfg)``. When profiling is disabled,
        ``context`` is a ``nullcontext`` and ``profile_dir`` is ``None``.
        Otherwise ``profile_dir`` (under ``results_path``) has been created.
    """
    profile_cfg = get_profile_cfg(cfg)
    if not profile_cfg["enabled"]:
        return nullcontext(), None, profile_cfg

    activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(torch.profiler.ProfilerActivity.CUDA)

    profile_dir = results_path / profile_cfg["trace_dirname"]
    profile_dir.mkdir(parents=True, exist_ok=True)
    profiler = torch.profiler.profile(
        activities=activities,
        record_shapes=profile_cfg["record_shapes"],
        profile_memory=profile_cfg["profile_memory"],
        with_stack=profile_cfg["with_stack"],
        with_flops=profile_cfg["with_flops"],
    )
    return profiler, profile_dir, profile_cfg


def dump_profiler_results(profiler, profile_dir, profile_cfg):
    """Write profiler output to ``profile_dir``.

    Writes the ``key_averages`` table to ``key_averages.txt``, sorted by self
    CUDA time when CUDA is available and by self CPU time otherwise. Then
    exports a TensorBoard trace, or a ``trace.json`` Chrome trace only when the
    TensorBoard export is disabled.

    Returns:
        Path of the summary file, or ``None`` when profiling was not active.
    """
    if profiler is None or profile_dir is None:
        return None

    has_cuda = torch.cuda.is_available()
    table = profiler.key_averages().table(
        sort_by="self_cuda_time_total" if has_cuda else "self_cpu_time_total",
        row_limit=profile_cfg["row_limit"],
    )
    summary_path = profile_dir / "key_averages.txt"
    summary_path.write_text(table)

    if profile_cfg["export_tensorboard"]:
        trace_handler = torch.profiler.tensorboard_trace_handler(
            str(profile_dir), worker_name=profile_cfg["worker_name"]
        )
        trace_handler(profiler)
    elif profile_cfg["export_chrome_trace"]:
        profiler.export_chrome_trace(str(profile_dir / "trace.json"))
    return summary_path


def get_multi_gpu_cfg(cfg):
    """Return the multi-GPU settings: built-in defaults updated by ``cfg.multi_gpu``."""
    defaults = {
        "enabled": False,
        "devices": None,
        "start_method": "spawn",
    }
    return _merge_cfg_section(defaults, cfg, "multi_gpu")


def build_process(cfg, dataset):
    """Fit a ``StandardScaler`` for every cached column except ``pixels``.

    Rows containing any NaN are dropped before fitting. For every column other
    than ``action``, the same fitted scaler is also registered under
    ``goal_<col>`` so goals share the observation normalization.
    """
    process = {}
    for col in cfg.dataset.keys_to_cache:
        if col == "pixels":
            continue
        col_data = dataset.get_col_data(col)
        # any(axis=1) requires 2-D (rows, features) column data.
        col_data = col_data[~np.isnan(col_data).any(axis=1)]
        scaler = preprocessing.StandardScaler().fit(col_data)
        process[col] = scaler
        if col != "action":
            process[f"goal_{col}"] = scaler
    return process


def sample_eval_cases(cfg, dataset):
    """Sample ``cfg.eval.num_eval`` (episode id, start step) evaluation cases.

    A dataset row is a valid start when its episode still contains
    ``cfg.eval.goal_offset_steps`` more steps, so the goal lies inside the
    episode. Starts are drawn uniformly without replacement from all valid rows
    with ``cfg.seed`` and returned in ascending row order.

    Returns:
        ``(episode_ids, start_steps)`` taken from ``dataset.get_row_data``.

    Raises:
        ValueError: if there are fewer valid start rows than ``num_eval``.
    """
    col_name = _episode_column(dataset)
    episode_col = dataset.get_col_data(col_name)
    step_col = dataset.get_col_data("step_idx")

    ep_indices = np.unique(episode_col)
    episode_len = get_episodes_length(dataset, ep_indices)
    # Latest start step whose goal (goal_offset_steps later) is still in-episode.
    max_start_idx = episode_len - cfg.eval.goal_offset_steps - 1
    # np.unique output is sorted and contains every id, so searchsorted maps
    # each row to its episode's position (vectorised per-row lookup).
    max_start_per_row = max_start_idx[np.searchsorted(ep_indices, episode_col)]
    valid_indices = np.nonzero(step_col <= max_start_per_row)[0]
    print(len(valid_indices), "valid starting points found for evaluation.")

    num_eval = cfg.eval.num_eval
    # Check before sampling: choice(replace=False) would otherwise fail with a
    # less helpful numpy error.
    if len(valid_indices) < num_eval:
        raise ValueError("Not enough episodes with sufficient length for evaluation.")

    rng = np.random.default_rng(cfg.seed)
    # Draw over the full range of valid rows, including the last one.
    chosen = rng.choice(len(valid_indices), size=num_eval, replace=False)
    random_episode_indices = np.sort(valid_indices[chosen])
    print(random_episode_indices)

    rows = dataset.get_row_data(random_episode_indices)
    return rows[col_name], rows["step_idx"]


def normalize_multi_gpu_devices(devices):
    """Turn a device spec into a list of device strings.

    ``None`` expands to every visible CUDA device. Ints and digit strings
    become ``cuda:<n>``. Anything else is kept as ``str(device)``.
    """
    if devices is None:
        return [f"cuda:{idx}" for idx in range(torch.cuda.device_count())]
    normalized = []
    for device in devices:
        if isinstance(device, int):
            normalized.append(f"cuda:{device}")
        elif isinstance(device, str) and device.isdigit():
            normalized.append(f"cuda:{int(device)}")
        else:
            normalized.append(str(device))
    return normalized


def shard_eval_cases(eval_episodes, eval_start_idx, num_shards):
    """Split the evaluation cases into at most ``num_shards`` contiguous shards.

    Shard sizes differ by at most one, with earlier shards taking the
    remainder. Empty shards are dropped, so fewer shards may be returned.

    Raises:
        ValueError: if ``num_shards < 1``.
    """
    if num_shards < 1:
        raise ValueError("num_shards must be >= 1")
    base, remainder = divmod(len(eval_episodes), num_shards)
    shards = []
    start = 0
    for shard_idx in range(num_shards):
        size = base + (1 if shard_idx < remainder else 0)
        if size > 0:
            end = start + size
            shards.append((eval_episodes[start:end], eval_start_idx[start:end]))
            start = end
    return shards


def _build_policy(cfg, process, transform, device):
    """Create the evaluation policy on ``device``.

    ``cfg.policy`` left at (or set to) ``"random"`` gives a ``RandomPolicy``
    with fp32 and no compilation. Otherwise a frozen ``AutoCostModel`` is
    wrapped in the configured solver and a ``WorldModelPolicy``.

    Returns:
        ``(policy, inference_ctx, inference_precision, compile_cfg, compile_target)``.
    """
    if cfg.get("policy", "random") == "random":
        return swm.policy.RandomPolicy(), nullcontext(), "fp32", get_compile_cfg(cfg), "disabled"

    model = swm.policy.AutoCostModel(cfg.policy)
    model = model.to(device)
    model = model.eval()
    model.requires_grad_(False)
    model, compile_cfg, compile_target = maybe_compile_inference_target(model, cfg, device)
    inference_ctx, inference_precision = get_inference_context(cfg, device)
    model.interpolate_pos_encoding = True

    config = swm.PlanConfig(**cfg.plan_config)
    solver = hydra.utils.instantiate(cfg.solver, model=model)
    policy = swm.policy.WorldModelPolicy(
        solver=solver, config=config, process=process, transform=transform
    )
    return policy, inference_ctx, inference_precision, compile_cfg, compile_target


def run_eval_subset(
    cfg: DictConfig,
    eval_episodes,
    eval_start_idx,
    output_dir: Path,
    *,
    device_override: str | None = None,
    enable_profile: bool = True,
):
    """Evaluate the configured policy on one set of (episode, start step) cases.

    Works on a private copy of ``cfg`` sized to the given cases: one env per
    case, with ``max_episode_steps`` set to twice ``eval_budget``.
    ``device_override`` pins the solver and model to one device (used by
    multi-GPU workers). ``enable_profile=False`` forces profiling off.

    Returns:
        dict with ``metrics``, ``evaluation_time`` (seconds spent in
        ``world.evaluate_from_dataset``, including a CUDA sync),
        ``inference_precision``, ``compile_target``, ``compile_mode`` (``None``
        when not compiled), ``profile_dir`` and ``profile_summary_path``.
    """
    local_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
    local_cfg.eval.num_eval = len(eval_episodes)
    local_cfg.world.num_envs = len(eval_episodes)
    local_cfg.world.max_episode_steps = 2 * local_cfg.eval.eval_budget
    if device_override is not None:
        local_cfg.solver.device = device_override
        if torch.cuda.is_available() and str(device_override).startswith("cuda"):
            cuda_device = torch.device(device_override)
            # torch.cuda.set_device rejects a device without an explicit index.
            if cuda_device.index is not None:
                torch.cuda.set_device(cuda_device)
    if not enable_profile:
        if local_cfg.get("profile") is None:
            local_cfg.profile = OmegaConf.create({"enabled": False})
        else:
            local_cfg.profile.enabled = False

    world = swm.World(**local_cfg.world, image_shape=(224, 224))
    transform = {
        "pixels": img_transform(local_cfg),
        "goal": img_transform(local_cfg),
    }
    dataset = get_dataset(local_cfg, local_cfg.eval.dataset_name)
    process = build_process(local_cfg, dataset)

    device = device_override or ("cuda" if torch.cuda.is_available() else "cpu")
    policy, inference_ctx, inference_precision, compile_cfg, compile_target = _build_policy(
        local_cfg, process, transform, device
    )

    profiler_ctx, profile_dir, profile_cfg = make_profiler(local_cfg, output_dir)
    world.set_policy(policy)

    # eval.callables is optional; OmegaConf.to_container rejects None/non-configs.
    callables_cfg = local_cfg.eval.get("callables")
    callables = (
        OmegaConf.to_container(callables_cfg, resolve=True)
        if OmegaConf.is_config(callables_cfg)
        else callables_cfg
    )

    on_cuda = str(device).startswith("cuda") and torch.cuda.is_available()
    if on_cuda:
        torch.cuda.synchronize()
    start_time = time.time()
    with torch.inference_mode(), profiler_ctx as profiler, inference_ctx:
        with torch.profiler.record_function("eval.world_evaluate_from_dataset"):
            metrics = world.evaluate_from_dataset(
                dataset,
                start_steps=list(eval_start_idx),
                goal_offset_steps=local_cfg.eval.goal_offset_steps,
                eval_budget=local_cfg.eval.eval_budget,
                episodes_idx=list(eval_episodes),
                callables=callables,
                save_video=False,
                video_path=output_dir,
            )
        if on_cuda:
            # Wait for queued GPU work so the timing covers all of it.
            torch.cuda.synchronize()
        # Taken before the profiler exits so trace post-processing is not timed.
        evaluation_time = time.time() - start_time

    profile_summary_path = dump_profiler_results(profiler, profile_dir, profile_cfg)
    return {
        "metrics": metrics,
        "evaluation_time": evaluation_time,
        "inference_precision": inference_precision,
        "compile_target": compile_target,
        "compile_mode": compile_cfg["mode"] if compile_target != "disabled" else None,
        "profile_dir": profile_dir,
        "profile_summary_path": profile_summary_path,
    }


def _multi_gpu_eval_worker(
    cfg_container,
    eval_episodes,
    eval_start_idx,
    output_dir,
    device,
    shard_idx,
    queue,
):
    """Worker process entry point: evaluate one shard on ``device``.

    Puts exactly one message on ``queue``: ``{"ok": True, "shard_idx",
    "result"}`` on success, or ``{"ok": False, "shard_idx", "error"}`` carrying
    the formatted traceback of any Python exception.
    """
    try:
        cfg = OmegaConf.create(cfg_container)
        result = run_eval_subset(
            cfg,
            eval_episodes,
            eval_start_idx,
            Path(output_dir),
            device_override=device,
            enable_profile=False,
        )
        queue.put({"ok": True, "shard_idx": shard_idx, "result": result})
    except Exception:
        queue.put(
            {
                "ok": False,
                "shard_idx": shard_idx,
                "error": traceback.format_exc(),
            }
        )


def _collect_worker_messages(result_queue, processes, poll_interval=10.0):
    """Receive one message per worker, ordered by shard index.

    A worker that dies without reporting (for example killed, or crashed in
    native code that its ``except Exception`` cannot catch) would make a plain
    ``queue.get()`` block forever. Instead the queue is polled, and a silent
    worker observed dead at two consecutive timeouts gets a synthetic error
    message. The grace timeout leaves time for any message it flushed just
    before exiting to arrive.
    """
    messages = {}
    suspected_dead = set()
    while len(messages) < len(processes):
        try:
            message = result_queue.get(timeout=poll_interval)
        except stdlib_queue.Empty:
            dead = {
                idx
                for idx, proc in enumerate(processes)
                if idx not in messages and not proc.is_alive()
            }
            for idx in dead & suspected_dead:
                messages[idx] = {
                    "ok": False,
                    "shard_idx": idx,
                    "error": (
                        f"Worker for shard {idx} exited with code "
                        f"{processes[idx].exitcode} without reporting a result."
                    ),
                }
            suspected_dead = dead - messages.keys()
            continue
        messages[message["shard_idx"]] = message
    return [messages[idx] for idx in range(len(processes))]


def run_multi_gpu_eval(cfg, eval_episodes, eval_start_idx, output_dir: Path):
    """Shard the evaluation cases across GPUs, one worker process per device.

    Per-episode successes (and seeds, when every shard reports them) are
    concatenated in shard order and the success rate is recomputed.
    ``evaluation_time`` is wall-clock time including worker start-up and model
    loading, so it is not directly comparable to the single-GPU timing.

    Raises:
        ValueError: if fewer than two devices are configured or visible.
        RuntimeError: if any worker fails; the message contains every
            worker's error.
    """
    multi_gpu_cfg = get_multi_gpu_cfg(cfg)
    devices = normalize_multi_gpu_devices(multi_gpu_cfg["devices"])
    if len(devices) < 2:
        raise ValueError("multi_gpu.enabled=true requires at least 2 CUDA devices")

    shards = shard_eval_cases(
        eval_episodes, eval_start_idx, min(len(devices), len(eval_episodes))
    )
    devices = devices[: len(shards)]

    ctx = mp.get_context(multi_gpu_cfg["start_method"])
    result_queue = ctx.Queue()
    cfg_container = OmegaConf.to_container(cfg, resolve=False)

    processes = []
    start_time = time.time()
    for shard_idx, ((shard_episodes, shard_start_idx), device) in enumerate(
        zip(shards, devices, strict=True)
    ):
        process = ctx.Process(
            target=_multi_gpu_eval_worker,
            args=(
                cfg_container,
                list(shard_episodes),
                list(shard_start_idx),
                str(output_dir),
                device,
                shard_idx,
                result_queue,
            ),
        )
        process.start()
        processes.append(process)

    # Drain the queue before joining: joining a process that still has queued
    # data to flush can deadlock.
    messages = _collect_worker_messages(result_queue, processes)
    for process in processes:
        process.join()

    errors = [message["error"] for message in messages if not message["ok"]]
    if errors:
        raise RuntimeError("\n\n".join(errors))

    ordered_results = [message["result"] for message in messages]
    episode_successes = np.concatenate(
        [
            np.asarray(result["metrics"]["episode_successes"], dtype=np.bool_)
            for result in ordered_results
        ]
    )
    shard_seeds = [result["metrics"].get("seeds") for result in ordered_results]
    seeds = (
        np.concatenate(shard_seeds)
        if all(seed is not None for seed in shard_seeds)
        else None
    )
    metrics = {
        "success_rate": float(np.sum(episode_successes)) / len(episode_successes) * 100.0,
        "episode_successes": episode_successes,
        "seeds": seeds,
    }

    reference = ordered_results[0]
    return {
        "metrics": metrics,
        "evaluation_time": time.time() - start_time,
        "inference_precision": reference["inference_precision"],
        "compile_target": reference["compile_target"],
        "compile_mode": reference["compile_mode"],
        "profile_dir": None,
        "profile_summary_path": None,
    }


@hydra.main(version_base=None, config_path="./config/eval", config_name="pusht")
def run(cfg: DictConfig):
    """Run evaluation of the configured policy (world-model planner or random).

    Samples evaluation cases from the dataset, evaluates them on one device or
    across several GPUs, prints the metrics and appends the config and
    results to ``cfg.output.filename`` under the working directory.

    Raises:
        ValueError: if the planning horizon exceeds ``eval.eval_budget``, or
            if profiling is combined with multi-GPU evaluation.
    """
    # Raise instead of assert so the check survives ``python -O``.
    if cfg.plan_config.horizon * cfg.plan_config.action_block > cfg.eval.eval_budget:
        raise ValueError("Planning horizon must be smaller than or equal to eval_budget")

    dataset = get_dataset(cfg, cfg.eval.dataset_name)
    eval_episodes, eval_start_idx = sample_eval_cases(cfg, dataset)
    # Plain Python lists: the same types for both paths, and cheap to pickle for workers.
    eval_episodes = eval_episodes.tolist()
    eval_start_idx = eval_start_idx.tolist()
    output_dir = Path.cwd().resolve()
    profile_cfg = get_profile_cfg(cfg)

    if get_multi_gpu_cfg(cfg)["enabled"]:
        if profile_cfg["enabled"]:
            raise ValueError("Profiling is not supported together with multi_gpu.enabled=true")
        eval_result = run_multi_gpu_eval(cfg, eval_episodes, eval_start_idx, output_dir)
    else:
        eval_result = run_eval_subset(cfg, eval_episodes, eval_start_idx, output_dir)

    metrics = eval_result["metrics"]
    compile_target = eval_result["compile_target"]
    print(metrics)

    results_path = output_dir / cfg.output.filename
    results_path.parent.mkdir(parents=True, exist_ok=True)
    with results_path.open("a", encoding="utf-8") as f:
        f.write("\n")  # separate from previous runs
        f.write("==== CONFIG ====\n")
        f.write(OmegaConf.to_yaml(cfg))
        f.write("\n")
        f.write("==== RESULTS ====\n")
        f.write(f"metrics: {metrics}\n")
        f.write(f"evaluation_time: {eval_result['evaluation_time']} seconds\n")
        f.write(f"inference_precision: {eval_result['inference_precision']}\n")
        f.write(f"inference_compile_target: {compile_target}\n")
        if compile_target != "disabled":
            f.write(f"inference_compile_mode: {eval_result['compile_mode']}\n")
        if profile_cfg["enabled"]:
            f.write(f"profile_dir: {eval_result['profile_dir']}\n")
            if eval_result["profile_summary_path"] is not None:
                f.write(f"profile_summary: {eval_result['profile_summary_path']}\n")


if __name__ == "__main__":
    run()