多机
This commit is contained in:
206
eval.py
206
eval.py
@@ -96,6 +96,46 @@ def get_compile_warmup_cfg(cfg):
|
||||
return warmup_cfg
|
||||
|
||||
|
||||
def get_preload_wait_cfg(cfg):
    """Return the effective ``preload_wait`` settings as a plain dict.

    The result starts from built-in defaults (feature disabled, signal file
    ``/tmp/lewm_preload_start``, one-second poll interval) and is overlaid
    with the resolved contents of ``cfg["preload_wait"]`` when that section
    exists. A new dict is built on every call.
    """
    overrides = cfg.get("preload_wait")
    settings = {
        "enabled": False,
        "file": "/tmp/lewm_preload_start",
        "poll_interval": 1.0,
    }
    if overrides is None:
        return settings

    # Resolve interpolations so callers only ever see concrete values.
    settings.update(OmegaConf.to_container(overrides, resolve=True))
    return settings
|
||||
|
||||
|
||||
def wait_for_preload_signal(cfg, rank=0):
    """Hold evaluation until an operator creates the preload signal file.

    Does nothing when ``preload_wait.enabled`` is false. Otherwise:

    1. If a process group is initialised, all ranks meet at a barrier.
    2. Rank 0 prints the signal path and polls for the file, sleeping
       ``poll_interval`` seconds between checks.
    3. If a process group is initialised, all ranks meet at a second barrier,
       which releases the non-zero ranks once rank 0 has seen the file.

    Args:
        cfg: Configuration object passed to ``get_preload_wait_cfg``.
        rank: Rank of the calling process; only rank 0 polls the file.
    """
    settings = get_preload_wait_cfg(cfg)
    if not settings["enabled"]:
        return

    dist = torch.distributed
    use_barrier = dist.is_available() and dist.is_initialized()

    # First rendezvous, before the prompt is printed.
    if use_barrier:
        dist.barrier()

    signal_path = Path(str(settings["file"])).expanduser()
    poll_interval = float(settings["poll_interval"])

    if rank == 0:
        print(
            "Preload ready. Create this file to start evaluation: "
            f"{signal_path}",
            flush=True,
        )
        while True:
            if signal_path.exists():
                break
            time.sleep(poll_interval)
        print("Preload start signal received. Starting evaluation.", flush=True)

    # Second rendezvous: the other ranks wait here while rank 0 polls.
    if use_barrier:
        dist.barrier()
|
||||
|
||||
|
||||
def maybe_compile_inference_target(model, cfg, device):
|
||||
compile_cfg = get_compile_cfg(cfg)
|
||||
compile_target = "disabled"
|
||||
@@ -229,6 +269,55 @@ def get_multi_gpu_cfg(cfg):
|
||||
return multi_gpu_cfg
|
||||
|
||||
|
||||
def get_multi_node_cfg(cfg):
    """Return the effective ``multi_node`` settings as a plain dict.

    Defaults: disabled, ``gloo`` backend, rank / world size / local rank read
    from the ``RANK`` / ``WORLD_SIZE`` / ``LOCAL_RANK`` environment variables,
    and ``output_mode`` set to ``"single"``. Values in ``cfg["multi_node"]``
    (resolved) override these defaults when that section exists. A new dict
    is built on every call.
    """
    overrides = cfg.get("multi_node")
    settings = {
        "enabled": False,
        "backend": "gloo",
        "rank_env": "RANK",
        "world_size_env": "WORLD_SIZE",
        "local_rank_env": "LOCAL_RANK",
        "output_mode": "single",
    }
    if overrides is None:
        return settings

    # Resolve interpolations so callers only ever see concrete values.
    settings.update(OmegaConf.to_container(overrides, resolve=True))
    return settings
|
||||
|
||||
|
||||
def get_dist_env(name, default=None):
    """Read an integer-valued environment variable.

    Args:
        name: Name of the environment variable to look up.
        default: Fallback used when the variable is unset. It goes through
            the same ``int`` conversion unless it is ``None``.

    Returns:
        The value converted with ``int``, or ``None`` when the variable is
        unset and ``default`` is ``None``.

    Raises:
        ValueError: If the value cannot be parsed as an integer. The message
            names the variable so a misconfigured launcher is easy to spot.
    """
    raw = os.environ.get(name, default)
    if raw is None:
        return None
    try:
        return int(raw)
    except ValueError as err:
        # A bare "invalid literal for int()" does not say which variable was bad.
        raise ValueError(
            f"Environment variable {name} must be an integer, got {raw!r}"
        ) from err
|
||||
|
||||
|
||||
def get_rank_context(cfg):
    """Return ``(rank, world_size, local_rank)`` for the current process.

    When ``multi_node.enabled`` is false the single-process context
    ``(0, 1, 0)`` is returned without reading the environment. Otherwise the
    three values are read from the environment variables named by
    ``multi_node.rank_env``, ``multi_node.world_size_env`` and
    ``multi_node.local_rank_env``; the local rank defaults to 0 when unset.

    Raises:
        ValueError: If rank or world size is missing, world size is below 1,
            rank is outside ``[0, world_size)``, or local rank is negative.
            Messages use the configured variable names, so with the default
            names they read the same as before.
    """
    multi_node_cfg = get_multi_node_cfg(cfg)
    if not multi_node_cfg["enabled"]:
        return 0, 1, 0

    rank_env = multi_node_cfg["rank_env"]
    world_size_env = multi_node_cfg["world_size_env"]
    local_rank_env = multi_node_cfg["local_rank_env"]

    rank = get_dist_env(rank_env)
    world_size = get_dist_env(world_size_env)
    local_rank = get_dist_env(local_rank_env, 0)

    if rank is None or world_size is None:
        raise ValueError(
            "multi_node.enabled=true requires torchrun env vars "
            f"{rank_env} and {world_size_env}"
        )
    if world_size < 1:
        raise ValueError(f"{world_size_env} must be >= 1")
    if not 0 <= rank < world_size:
        raise ValueError(f"{rank_env} must be in [0, {world_size_env})")
    # Callers build "cuda:{local_rank}" from this value, so reject negatives early.
    if local_rank < 0:
        raise ValueError(f"{local_rank_env} must be >= 0")
    return rank, world_size, local_rank
|
||||
|
||||
|
||||
def all_gather_eval_result(result):
    """Gather every rank's ``result`` object and return them as a list.

    Requires an initialised ``torch.distributed`` process group. The list has
    one slot per rank in the group and is filled by ``all_gather_object``.
    """
    dist = torch.distributed
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, result)
    return gathered
|
||||
|
||||
|
||||
def build_process(cfg, dataset):
|
||||
process = {}
|
||||
for col in cfg.dataset.keys_to_cache:
|
||||
@@ -311,6 +400,22 @@ def shard_eval_cases(eval_episodes, eval_start_idx, num_shards):
|
||||
return shards
|
||||
|
||||
|
||||
def get_rank_eval_subset(eval_episodes, eval_start_idx, rank, world_size):
    """Return the contiguous slice of evaluation cases owned by ``rank``.

    The cases are split into ``world_size`` consecutive shards whose sizes
    differ by at most one; the first ``len(eval_episodes) % world_size``
    shards receive the extra case. The same slice bounds are applied to both
    sequences.

    Returns:
        ``(eval_episodes[start:stop], eval_start_idx[start:stop])``.

    Raises:
        ValueError: If ``world_size < 1`` or ``rank`` is outside
            ``[0, world_size)``.
    """
    if world_size < 1:
        raise ValueError("world_size must be >= 1")
    if rank < 0 or rank >= world_size:
        raise ValueError("rank must be in [0, world_size)")

    base, remainder = divmod(len(eval_episodes), world_size)
    # Each earlier rank holds `base` cases, plus one extra for the first
    # `remainder` ranks.
    start = rank * base + min(rank, remainder)
    stop = start + base + (1 if rank < remainder else 0)
    return eval_episodes[start:stop], eval_start_idx[start:stop]
|
||||
|
||||
|
||||
def run_eval_subset(
|
||||
cfg: DictConfig,
|
||||
eval_episodes,
|
||||
@@ -319,6 +424,7 @@ def run_eval_subset(
|
||||
*,
|
||||
device_override: str | None = None,
|
||||
enable_profile: bool = True,
|
||||
before_evaluate=None,
|
||||
):
|
||||
local_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
|
||||
local_cfg.eval.num_eval = len(eval_episodes)
|
||||
@@ -376,6 +482,11 @@ def run_eval_subset(
|
||||
if str(device).startswith("cuda") and torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
|
||||
if before_evaluate is not None:
|
||||
before_evaluate()
|
||||
if str(device).startswith("cuda") and torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def evaluate_subset(episodes, start_indices, *, eval_cfg=local_cfg):
|
||||
return world.evaluate_from_dataset(
|
||||
dataset,
|
||||
@@ -421,6 +532,15 @@ def maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx):
|
||||
print("Skipping compile warmup because multi_gpu.enabled=true uses spawned workers.")
|
||||
return
|
||||
|
||||
if get_multi_node_cfg(cfg)["enabled"]:
|
||||
rank, world_size, local_rank = get_rank_context(cfg)
|
||||
eval_episodes, eval_start_idx = get_rank_eval_subset(
|
||||
eval_episodes, eval_start_idx, rank, world_size
|
||||
)
|
||||
device_override = f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu"
|
||||
else:
|
||||
device_override = None
|
||||
|
||||
warmup_count = min(int(warmup_cfg["num_eval"]), len(eval_episodes))
|
||||
if warmup_count < 1:
|
||||
return
|
||||
@@ -439,6 +559,7 @@ def maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx):
|
||||
eval_episodes[:warmup_count].tolist(),
|
||||
eval_start_idx[:warmup_count].tolist(),
|
||||
Path(tmpdir),
|
||||
device_override=device_override,
|
||||
enable_profile=False,
|
||||
)
|
||||
|
||||
@@ -551,6 +672,82 @@ def run_multi_gpu_eval(cfg, eval_episodes, eval_start_idx, output_dir: Path):
|
||||
"profile_summary_path": None,
|
||||
}
|
||||
|
||||
|
||||
def combine_eval_results(ordered_results):
    """Merge per-shard evaluation results into a single metrics dict.

    Args:
        ordered_results: Sequence of result dicts, each with a ``"metrics"``
            dict containing ``"episode_successes"`` and optionally
            ``"seeds"``.

    Returns:
        ``(metrics, reference)``. ``metrics`` holds the concatenated boolean
        ``episode_successes`` array, the overall ``success_rate`` in percent,
        and the concatenated ``seeds`` (``None`` unless every shard supplied
        seeds). ``reference`` is the first element of ``ordered_results``.
    """
    shard_metrics = [entry["metrics"] for entry in ordered_results]

    success_chunks = [
        np.asarray(shard["episode_successes"], dtype=np.bool_)
        for shard in shard_metrics
    ]
    episode_successes = np.concatenate(success_chunks)

    seed_chunks = [shard.get("seeds") for shard in shard_metrics]
    if any(chunk is None for chunk in seed_chunks):
        # Seeds are only reported when no shard is missing them.
        seeds = None
    else:
        seeds = np.concatenate(seed_chunks)

    num_success = float(np.sum(episode_successes))
    success_rate = num_success / len(episode_successes) * 100.0

    metrics = {
        "success_rate": success_rate,
        "episode_successes": episode_successes,
        "seeds": seeds,
    }
    return metrics, ordered_results[0]
|
||||
|
||||
|
||||
def _ensure_multi_node_process_group(cfg, unavailable_message):
    """Initialise the default process group with the configured backend if needed."""
    if not torch.distributed.is_available():
        raise RuntimeError(unavailable_message)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(
            backend=get_multi_node_cfg(cfg)["backend"]
        )


def run_multi_node_eval(cfg, eval_episodes, eval_start_idx, output_dir: Path):
    """Evaluate this rank's shard and gather the combined result.

    Each rank evaluates the contiguous shard chosen by
    ``get_rank_eval_subset`` on ``cuda:{local_rank}`` (or CPU when CUDA is
    unavailable), with ``multi_node`` and ``multi_gpu`` disabled in the
    per-rank config copy. The per-rank results are exchanged with
    ``all_gather_eval_result`` and merged with ``combine_eval_results``.

    Args:
        cfg: Full evaluation config.
        eval_episodes: Episode identifiers for the whole evaluation set.
        eval_start_idx: Start indices aligned with ``eval_episodes``.
        output_dir: Directory forwarded to ``run_eval_subset``.

    Returns:
        On rank 0, a dict with the merged ``metrics``, the maximum
        ``evaluation_time`` across ranks, and the precision / compile fields
        of the first gathered result. ``None`` on every other rank.

    Raises:
        ValueError: If there are fewer episodes than ranks.
        RuntimeError: If ``torch.distributed`` is unavailable.
    """
    rank, world_size, local_rank = get_rank_context(cfg)

    # Fail on every rank, not just the starved ones: with fewer episodes than
    # ranks some shard is empty, and raising only there would leave the other
    # ranks evaluating and then blocking in the collective below.
    # NOTE(review): assumes all ranks receive the same episode list - confirm.
    if len(eval_episodes) < world_size:
        raise ValueError(
            "multi-node evaluation needs at least one episode per rank: got "
            f"{len(eval_episodes)} episodes for world size {world_size}"
        )

    shard_episodes, shard_start_idx = get_rank_eval_subset(
        eval_episodes, eval_start_idx, rank, world_size
    )

    # Per-rank copy with multi-process modes switched off, so the nested
    # evaluation is configured as a plain single-process run.
    local_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
    local_cfg.multi_node.enabled = False
    if local_cfg.get("multi_gpu") is None:
        local_cfg.multi_gpu = OmegaConf.create({"enabled": False})
    else:
        local_cfg.multi_gpu.enabled = False

    device = f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu"

    # The preload wait uses barriers, so the group must exist before evaluation.
    if get_preload_wait_cfg(cfg)["enabled"]:
        _ensure_multi_node_process_group(
            cfg, "torch.distributed is required for preload_wait"
        )

    result = run_eval_subset(
        local_cfg,
        list(shard_episodes),
        list(shard_start_idx),
        output_dir,
        device_override=device,
        enable_profile=False,
        before_evaluate=lambda: wait_for_preload_signal(cfg, rank=rank),
    )

    _ensure_multi_node_process_group(
        cfg, "torch.distributed is required for multi-node evaluation"
    )

    gathered = all_gather_eval_result(result)
    metrics, reference = combine_eval_results(gathered)
    combined = {
        "metrics": metrics,
        # Ranks run concurrently, so the slowest rank determines the time.
        "evaluation_time": max(item["evaluation_time"] for item in gathered),
        "inference_precision": reference["inference_precision"],
        "compile_target": reference["compile_target"],
        "compile_mode": reference["compile_mode"],
        "profile_dir": None,
        "profile_summary_path": None,
    }
    torch.distributed.barrier()
    if rank != 0:
        return None
    return combined
|
||||
|
||||
@hydra.main(version_base=None, config_path="./config/eval", config_name="pusht")
|
||||
def run(cfg: DictConfig):
|
||||
"""Run evaluation of dinowm vs random policy."""
|
||||
@@ -565,7 +762,14 @@ def run(cfg: DictConfig):
|
||||
|
||||
maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx)
|
||||
|
||||
if get_multi_gpu_cfg(cfg)["enabled"]:
|
||||
if get_multi_node_cfg(cfg)["enabled"] and get_multi_gpu_cfg(cfg)["enabled"]:
|
||||
raise ValueError("multi_node.enabled and multi_gpu.enabled are mutually exclusive")
|
||||
|
||||
if get_multi_node_cfg(cfg)["enabled"]:
|
||||
eval_result = run_multi_node_eval(cfg, eval_episodes, eval_start_idx, output_dir)
|
||||
if eval_result is None:
|
||||
return
|
||||
elif get_multi_gpu_cfg(cfg)["enabled"]:
|
||||
if profile_cfg["enabled"]:
|
||||
raise ValueError("Profiling is not supported together with multi_gpu.enabled=true")
|
||||
eval_result = run_multi_gpu_eval(cfg, eval_episodes, eval_start_idx, output_dir)
|
||||
|
||||
Reference in New Issue
Block a user