From 113e5918994f555c44fc15cbd037b006b2836b0b Mon Sep 17 00:00:00 2001 From: qihuanye Date: Sun, 17 May 2026 20:49:33 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A4=9A=E6=9C=BA=E8=B0=83=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + config/eval/cube.yaml | 4 +++ config/eval/pusht.yaml | 4 +++ config/eval/reacher.yaml | 4 +++ config/eval/tworoom.yaml | 4 +++ eval.py | 76 +++++++++++++++++++++++++++++++++++----- 6 files changed, 84 insertions(+), 9 deletions(-) diff --git a/.gitignore b/.gitignore index 682f75a..a4a541a 100644 --- a/.gitignore +++ b/.gitignore @@ -31,3 +31,4 @@ eval_tmp_*.npy .DS_Store .idea/ .vscode/ +*.log diff --git a/config/eval/cube.yaml b/config/eval/cube.yaml index 81ae72c..6ab4e33 100644 --- a/config/eval/cube.yaml +++ b/config/eval/cube.yaml @@ -67,6 +67,10 @@ multi_node: rank_env: RANK world_size_env: WORLD_SIZE local_rank_env: LOCAL_RANK + aggregate_results: true + sync_before_return: false + destroy_process_group: true + shard_strategy: round_robin preload_wait: enabled: false diff --git a/config/eval/pusht.yaml b/config/eval/pusht.yaml index b971486..a04838f 100644 --- a/config/eval/pusht.yaml +++ b/config/eval/pusht.yaml @@ -55,6 +55,10 @@ multi_node: rank_env: RANK world_size_env: WORLD_SIZE local_rank_env: LOCAL_RANK + aggregate_results: true + sync_before_return: false + destroy_process_group: true + shard_strategy: round_robin preload_wait: enabled: false diff --git a/config/eval/reacher.yaml b/config/eval/reacher.yaml index aac4e64..90f68a2 100644 --- a/config/eval/reacher.yaml +++ b/config/eval/reacher.yaml @@ -56,6 +56,10 @@ multi_node: rank_env: RANK world_size_env: WORLD_SIZE local_rank_env: LOCAL_RANK + aggregate_results: true + sync_before_return: false + destroy_process_group: true + shard_strategy: round_robin preload_wait: enabled: false diff --git a/config/eval/tworoom.yaml b/config/eval/tworoom.yaml index 12ede48..2a952e7 100644 --- a/config/eval/tworoom.yaml +++ b/config/eval/tworoom.yaml @@ -54,6 +54,10 @@ multi_node: rank_env: RANK world_size_env: WORLD_SIZE local_rank_env: LOCAL_RANK + aggregate_results: true + sync_before_return: false + destroy_process_group: true + shard_strategy: round_robin preload_wait: enabled: false diff --git a/eval.py b/eval.py index 1479cad..75fc2c3 100644 --- a/eval.py +++ b/eval.py @@ -277,6 +277,10 @@ def get_multi_node_cfg(cfg): "world_size_env": "WORLD_SIZE", "local_rank_env": "LOCAL_RANK", "output_mode": "single", + "aggregate_results": True, + "sync_before_return": False, + "destroy_process_group": True, + "shard_strategy": "round_robin", } cfg_multi_node = cfg.get("multi_node") if cfg_multi_node is not None: @@ -318,6 +322,28 @@ def all_gather_eval_result(result): return payload +def finalize_multi_node_process_group(cfg): + multi_node_cfg = get_multi_node_cfg(cfg) + if not multi_node_cfg["destroy_process_group"]: + return + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + +def get_rank_result_path(output_dir: Path, cfg: DictConfig, rank: int) -> Path: + filename = str(cfg.output.filename) + if rank == 0: + return output_dir / filename + + suffix = Path(filename).suffix + stem = Path(filename).stem + if suffix: + ranked_filename = f"{stem}.rank{rank}{suffix}" + else: + ranked_filename = f"{filename}.rank{rank}" + return output_dir / ranked_filename + + def build_process(cfg, dataset): process = {} for col in cfg.dataset.keys_to_cache: @@ -400,12 +426,26 @@ def shard_eval_cases(eval_episodes, eval_start_idx, num_shards): return shards -def get_rank_eval_subset(eval_episodes, eval_start_idx, rank, world_size): +def get_rank_eval_subset( + eval_episodes, + eval_start_idx, + rank, + world_size, + *, + strategy="contiguous", +): if world_size < 1: raise ValueError("world_size must be >= 1") if rank < 0 or rank >= world_size: raise ValueError("rank must be in [0, world_size)") + if strategy == "round_robin": + episode_subset = eval_episodes[rank::world_size] + start_subset = eval_start_idx[rank::world_size] + return episode_subset, start_subset + if strategy != "contiguous": + raise ValueError("strategy must be one of: contiguous, round_robin") + total = len(eval_episodes) shard_sizes = [total // world_size] * world_size for idx in range(total % world_size): @@ -697,8 +737,13 @@ def combine_eval_results(ordered_results): def run_multi_node_eval(cfg, eval_episodes, eval_start_idx, output_dir: Path): rank, world_size, local_rank = get_rank_context(cfg) + multi_node_cfg = get_multi_node_cfg(cfg) shard_episodes, shard_start_idx = get_rank_eval_subset( - eval_episodes, eval_start_idx, rank, world_size + eval_episodes, + eval_start_idx, + rank, + world_size, + strategy=multi_node_cfg["shard_strategy"], ) if len(shard_episodes) == 0: raise ValueError("No evaluation episodes assigned to this rank") @@ -716,21 +761,27 @@ def run_multi_node_eval(cfg, eval_episodes, eval_start_idx, output_dir: Path): if not torch.distributed.is_available(): raise RuntimeError("torch.distributed is required for preload_wait") if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend=get_multi_node_cfg(cfg)["backend"]) + torch.distributed.init_process_group(backend=multi_node_cfg["backend"]) + rank_output_path = get_rank_result_path(output_dir, cfg, rank) result = run_eval_subset( local_cfg, list(shard_episodes), list(shard_start_idx), - output_dir, + rank_output_path.parent, device_override=device, enable_profile=False, before_evaluate=lambda: wait_for_preload_signal(cfg, rank=rank), ) + if not multi_node_cfg["aggregate_results"]: + result["output_filename"] = rank_output_path.name + finalize_multi_node_process_group(cfg) + return result + if not torch.distributed.is_available(): raise RuntimeError("torch.distributed is required for multi-node evaluation") if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend=get_multi_node_cfg(cfg)["backend"]) + torch.distributed.init_process_group(backend=multi_node_cfg["backend"]) gathered = all_gather_eval_result(result) metrics, reference = combine_eval_results(gathered) @@ -742,8 +793,11 @@ def run_multi_node_eval(cfg, eval_episodes, eval_start_idx, output_dir: Path): "compile_mode": reference["compile_mode"], "profile_dir": None, "profile_summary_path": None, + "output_filename": cfg.output.filename, } - torch.distributed.barrier() + if multi_node_cfg["sync_before_return"]: + torch.distributed.barrier() + finalize_multi_node_process_group(cfg) if rank != 0: return None return combined @@ -761,6 +815,7 @@ def run(cfg: DictConfig): profile_cfg = get_profile_cfg(cfg) maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx) + eval_wall_start = time.time() if get_multi_node_cfg(cfg)["enabled"] and get_multi_gpu_cfg(cfg)["enabled"]: raise ValueError("multi_node.enabled and multi_gpu.enabled are mutually exclusive") @@ -788,10 +843,11 @@ def run(cfg: DictConfig): compile_mode = eval_result["compile_mode"] profile_dir = eval_result["profile_dir"] profile_summary_path = eval_result["profile_summary_path"] + output_filename = eval_result.get("output_filename", cfg.output.filename) print(metrics) - results_path = output_dir / cfg.output.filename + results_path = output_dir / output_filename results_path.parent.mkdir(parents=True, exist_ok=True) with results_path.open("a") as f: @@ -810,8 +866,10 @@ def run(cfg: DictConfig): f.write(f"inference_compile_mode: {compile_mode}\n") if profile_cfg["enabled"]: f.write(f"profile_dir: {profile_dir}\n") - if profile_summary_path is not None: - f.write(f"profile_summary: {profile_summary_path}\n") + if profile_summary_path is not None: + f.write(f"profile_summary: {profile_summary_path}\n") + + f.write(f"total_wall_time: {time.time() - eval_wall_start} seconds\n") if __name__ == "__main__":