From 0164e21f48d55460391ed67f2f63ad333350a3fa Mon Sep 17 00:00:00 2001 From: qihuanye Date: Sun, 17 May 2026 19:23:31 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A4=9A=E6=9C=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/eval/cube.yaml | 12 ++ config/eval/pusht.yaml | 12 ++ config/eval/reacher.yaml | 12 ++ config/eval/solver/cem.yaml | 2 +- config/eval/tworoom.yaml | 12 ++ eval.py | 206 ++++++++++++++++++++++++- scripts/launch_multinode_eval.sh | 255 +++++++++++++++++++++++++++++++ scripts/warmup_eval.sh | 22 ++- 8 files changed, 527 insertions(+), 6 deletions(-) create mode 100755 scripts/launch_multinode_eval.sh diff --git a/config/eval/cube.yaml b/config/eval/cube.yaml index 32e313f..81ae72c 100644 --- a/config/eval/cube.yaml +++ b/config/eval/cube.yaml @@ -61,5 +61,17 @@ eval: target_quat: value: goal_privileged_block_0_quat +multi_node: + enabled: false + backend: gloo + rank_env: RANK + world_size_env: WORLD_SIZE + local_rank_env: LOCAL_RANK + +preload_wait: + enabled: false + file: /tmp/lewm_preload_start + poll_interval: 1.0 + output: filename: ogb_cube_results.txt diff --git a/config/eval/pusht.yaml b/config/eval/pusht.yaml index d7f3930..b971486 100644 --- a/config/eval/pusht.yaml +++ b/config/eval/pusht.yaml @@ -48,6 +48,18 @@ eval: args: goal_state: value: goal_state + +multi_node: + enabled: false + backend: gloo + rank_env: RANK + world_size_env: WORLD_SIZE + local_rank_env: LOCAL_RANK + +preload_wait: + enabled: false + file: /tmp/lewm_preload_start + poll_interval: 1.0 output: filename: pusht_results.txt diff --git a/config/eval/reacher.yaml b/config/eval/reacher.yaml index 89c0c92..aac4e64 100644 --- a/config/eval/reacher.yaml +++ b/config/eval/reacher.yaml @@ -50,5 +50,17 @@ eval: target_qpos: value: goal_qpos +multi_node: + enabled: false + backend: gloo + rank_env: RANK + world_size_env: WORLD_SIZE + local_rank_env: LOCAL_RANK + +preload_wait: + enabled: false + file: /tmp/lewm_preload_start + poll_interval: 1.0 + output: filename: dmc_results.txt diff --git a/config/eval/solver/cem.yaml b/config/eval/solver/cem.yaml index 7830f78..add68d8 100644 --- a/config/eval/solver/cem.yaml +++ b/config/eval/solver/cem.yaml @@ -1,6 +1,6 @@ _target_: stable_worldmodel.solver.CEMSolver model: ??? -batch_size: 8 +batch_size: 16 # Original defaults: num_samples=300, n_steps=30, topk=30, batch_size=8. num_samples: 64 var_scale: 1.0 diff --git a/config/eval/tworoom.yaml b/config/eval/tworoom.yaml index 9cfec49..12ede48 100644 --- a/config/eval/tworoom.yaml +++ b/config/eval/tworoom.yaml @@ -48,5 +48,17 @@ eval: goal_state: value: goal_proprio +multi_node: + enabled: false + backend: gloo + rank_env: RANK + world_size_env: WORLD_SIZE + local_rank_env: LOCAL_RANK + +preload_wait: + enabled: false + file: /tmp/lewm_preload_start + poll_interval: 1.0 + output: filename: tworoom_results.txt diff --git a/eval.py b/eval.py index ae04df1..1479cad 100644 --- a/eval.py +++ b/eval.py @@ -96,6 +96,46 @@ def get_compile_warmup_cfg(cfg): return warmup_cfg +def get_preload_wait_cfg(cfg): + preload_cfg = { + "enabled": False, + "file": "/tmp/lewm_preload_start", + "poll_interval": 1.0, + } + cfg_preload = cfg.get("preload_wait") + if cfg_preload is not None: + preload_cfg.update(OmegaConf.to_container(cfg_preload, resolve=True)) + return preload_cfg + + +def wait_for_preload_signal(cfg, rank=0): + preload_cfg = get_preload_wait_cfg(cfg) + if not preload_cfg["enabled"]: + return + + dist_ready = ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + ) + if dist_ready: + torch.distributed.barrier() + + signal_path = Path(str(preload_cfg["file"])).expanduser() + poll_interval = float(preload_cfg["poll_interval"]) + if rank == 0: + print( + "Preload ready. Create this file to start evaluation: " + f"{signal_path}", + flush=True, + ) + while not signal_path.exists(): + time.sleep(poll_interval) + print("Preload start signal received. Starting evaluation.", flush=True) + + if dist_ready: + torch.distributed.barrier() + + def maybe_compile_inference_target(model, cfg, device): compile_cfg = get_compile_cfg(cfg) compile_target = "disabled" @@ -229,6 +269,55 @@ def get_multi_gpu_cfg(cfg): return multi_gpu_cfg +def get_multi_node_cfg(cfg): + multi_node_cfg = { + "enabled": False, + "backend": "gloo", + "rank_env": "RANK", + "world_size_env": "WORLD_SIZE", + "local_rank_env": "LOCAL_RANK", + "output_mode": "single", + } + cfg_multi_node = cfg.get("multi_node") + if cfg_multi_node is not None: + multi_node_cfg.update(OmegaConf.to_container(cfg_multi_node, resolve=True)) + return multi_node_cfg + + +def get_dist_env(name, default=None): + value = os.environ.get(name, default) + if value is None: + return None + return int(value) + + +def get_rank_context(cfg): + multi_node_cfg = get_multi_node_cfg(cfg) + if not multi_node_cfg["enabled"]: + return 0, 1, 0 + + rank = get_dist_env(multi_node_cfg["rank_env"]) + world_size = get_dist_env(multi_node_cfg["world_size_env"]) + local_rank = get_dist_env(multi_node_cfg["local_rank_env"], 0) + + if rank is None or world_size is None: + raise ValueError( + "multi_node.enabled=true requires torchrun env vars RANK and WORLD_SIZE" + ) + if world_size < 1: + raise ValueError("WORLD_SIZE must be >= 1") + if rank < 0 or rank >= world_size: + raise ValueError("RANK must be in [0, WORLD_SIZE)") + return rank, world_size, local_rank + + +def all_gather_eval_result(result): + world_size = torch.distributed.get_world_size() + payload = [None for _ in range(world_size)] + torch.distributed.all_gather_object(payload, result) + return payload + + def build_process(cfg, dataset): process = {} for col in cfg.dataset.keys_to_cache: @@ -311,6 +400,22 @@ def shard_eval_cases(eval_episodes, eval_start_idx, num_shards): return shards +def get_rank_eval_subset(eval_episodes, eval_start_idx, rank, world_size): + if world_size < 1: + raise ValueError("world_size must be >= 1") + if rank < 0 or rank >= world_size: + raise ValueError("rank must be in [0, world_size)") + + total = len(eval_episodes) + shard_sizes = [total // world_size] * world_size + for idx in range(total % world_size): + shard_sizes[idx] += 1 + + start = sum(shard_sizes[:rank]) + end = start + shard_sizes[rank] + return eval_episodes[start:end], eval_start_idx[start:end] + + def run_eval_subset( cfg: DictConfig, eval_episodes, @@ -319,6 +424,7 @@ def run_eval_subset( *, device_override: str | None = None, enable_profile: bool = True, + before_evaluate=None, ): local_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False)) local_cfg.eval.num_eval = len(eval_episodes) @@ -376,6 +482,11 @@ def run_eval_subset( if str(device).startswith("cuda") and torch.cuda.is_available(): torch.cuda.synchronize() + if before_evaluate is not None: + before_evaluate() + if str(device).startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize() + def evaluate_subset(episodes, start_indices, *, eval_cfg=local_cfg): return world.evaluate_from_dataset( dataset, @@ -421,6 +532,15 @@ def maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx): print("Skipping compile warmup because multi_gpu.enabled=true uses spawned workers.") return + if get_multi_node_cfg(cfg)["enabled"]: + rank, world_size, local_rank = get_rank_context(cfg) + eval_episodes, eval_start_idx = get_rank_eval_subset( + eval_episodes, eval_start_idx, rank, world_size + ) + device_override = f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu" + else: + device_override = None + warmup_count = min(int(warmup_cfg["num_eval"]), len(eval_episodes)) if warmup_count < 1: return @@ -439,6 +559,7 @@ def maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx): eval_episodes[:warmup_count].tolist(), eval_start_idx[:warmup_count].tolist(), Path(tmpdir), + device_override=device_override, enable_profile=False, ) @@ -551,6 +672,82 @@ def run_multi_gpu_eval(cfg, eval_episodes, eval_start_idx, output_dir: Path): "profile_summary_path": None, } + +def combine_eval_results(ordered_results): + episode_successes = np.concatenate( + [ + np.asarray(result["metrics"]["episode_successes"], dtype=np.bool_) + for result in ordered_results + ] + ) + + seeds = None + shard_seeds = [result["metrics"].get("seeds") for result in ordered_results] + if all(seed is not None for seed in shard_seeds): + seeds = np.concatenate(shard_seeds) + + metrics = { + "success_rate": float(np.sum(episode_successes)) / len(episode_successes) * 100.0, + "episode_successes": episode_successes, + "seeds": seeds, + } + reference = ordered_results[0] + return metrics, reference + + +def run_multi_node_eval(cfg, eval_episodes, eval_start_idx, output_dir: Path): + rank, world_size, local_rank = get_rank_context(cfg) + shard_episodes, shard_start_idx = get_rank_eval_subset( + eval_episodes, eval_start_idx, rank, world_size + ) + if len(shard_episodes) == 0: + raise ValueError("No evaluation episodes assigned to this rank") + + local_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False)) + local_cfg.multi_node.enabled = False + if local_cfg.get("multi_gpu") is None: + local_cfg.multi_gpu = OmegaConf.create({"enabled": False}) + else: + local_cfg.multi_gpu.enabled = False + + device = f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu" + preload_cfg = get_preload_wait_cfg(cfg) + if preload_cfg["enabled"]: + if not torch.distributed.is_available(): + raise RuntimeError("torch.distributed is required for preload_wait") + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend=get_multi_node_cfg(cfg)["backend"]) + + result = run_eval_subset( + local_cfg, + list(shard_episodes), + list(shard_start_idx), + output_dir, + device_override=device, + enable_profile=False, + before_evaluate=lambda: wait_for_preload_signal(cfg, rank=rank), + ) + if not torch.distributed.is_available(): + raise RuntimeError("torch.distributed is required for multi-node evaluation") + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend=get_multi_node_cfg(cfg)["backend"]) + + gathered = all_gather_eval_result(result) + metrics, reference = combine_eval_results(gathered) + combined = { + "metrics": metrics, + "evaluation_time": max(item["evaluation_time"] for item in gathered), + "inference_precision": reference["inference_precision"], + "compile_target": reference["compile_target"], + "compile_mode": reference["compile_mode"], + "profile_dir": None, + "profile_summary_path": None, + } + torch.distributed.barrier() + if rank != 0: + return None + return combined + @hydra.main(version_base=None, config_path="./config/eval", config_name="pusht") def run(cfg: DictConfig): """Run evaluation of dinowm vs random policy.""" @@ -565,7 +762,14 @@ def run(cfg: DictConfig): maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx) - if get_multi_gpu_cfg(cfg)["enabled"]: + if get_multi_node_cfg(cfg)["enabled"] and get_multi_gpu_cfg(cfg)["enabled"]: + raise ValueError("multi_node.enabled and multi_gpu.enabled are mutually exclusive") + + if get_multi_node_cfg(cfg)["enabled"]: + eval_result = run_multi_node_eval(cfg, eval_episodes, eval_start_idx, output_dir) + if eval_result is None: + return + elif get_multi_gpu_cfg(cfg)["enabled"]: if profile_cfg["enabled"]: raise ValueError("Profiling is not supported together with multi_gpu.enabled=true") eval_result = run_multi_gpu_eval(cfg, eval_episodes, eval_start_idx, output_dir) diff --git a/scripts/launch_multinode_eval.sh b/scripts/launch_multinode_eval.sh new file mode 100755 index 0000000..2660354 --- /dev/null +++ b/scripts/launch_multinode_eval.sh @@ -0,0 +1,255 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Launch 2-node LeWM evaluation from node-3. +# +# Defaults match the current cluster layout: +# node-3: 10.16.200.9, node_rank=0 +# node-2: 10.16.200.8, node_rank=1 +# Each node runs two local torchrun processes for two visible GPUs. + +REPO_ROOT="${REPO_ROOT:-/home/lewm/lewm}" +REMOTE_HOST="${REMOTE_HOST:-lewm@10.16.200.8}" +MASTER_ADDR="${MASTER_ADDR:-10.16.200.9}" +MASTER_PORT="${MASTER_PORT:-29500}" + +NNODES="${NNODES:-2}" +NPROC_PER_NODE="${NPROC_PER_NODE:-2}" +CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1}" +STABLEWM_HOME="${STABLEWM_HOME:-/home/lewm/.stable-wm}" + +CONFIG_NAME="${CONFIG_NAME:-pusht.yaml}" +POLICY="${POLICY:-pusht/lewm}" +OUTPUT_FILENAME="${OUTPUT_FILENAME:-pusht_multinode_results.txt}" +EXTRA_ARGS="${EXTRA_ARGS:-}" +DRY_RUN="${DRY_RUN:-0}" +TAIL_LOGS="${TAIL_LOGS:-1}" +PRELOAD_WAIT="${PRELOAD_WAIT:-0}" +PRELOAD_SIGNAL_FILE="${PRELOAD_SIGNAL_FILE:-/tmp/lewm_preload_start}" +PRELOAD_CLEAR_SIGNAL="${PRELOAD_CLEAR_SIGNAL:-1}" + +LOG_DIR="${LOG_DIR:-${REPO_ROOT}/logs/multinode}" +mkdir -p "${LOG_DIR}" +RUN_ID="$(date +%Y%m%d_%H%M%S)" +LOCAL_LOG="${LOG_DIR}/${RUN_ID}_node3_rank0.log" +REMOTE_LOG="${LOG_DIR}/${RUN_ID}_node2_rank1.log" + +SSH_OPTS=( + -F /dev/null + -o StrictHostKeyChecking=no + -o ServerAliveInterval=30 + -o ServerAliveCountMax=20 +) + +COMMON_ARGS=( + "--config-name=${CONFIG_NAME}" + "policy=${POLICY}" + "multi_node.enabled=true" + "output.filename=${OUTPUT_FILENAME}" +) + +if [[ "${PRELOAD_WAIT}" == "1" ]]; then + COMMON_ARGS+=( + "preload_wait.enabled=true" + "preload_wait.file=${PRELOAD_SIGNAL_FILE}" + ) +fi + +if [[ -n "${EXTRA_ARGS}" ]]; then + # shellcheck disable=SC2206 + COMMON_ARGS+=(${EXTRA_ARGS}) +fi + +make_command() { + local node_rank="$1" + local repo_q cuda_q stablewm_q arg_q eval_args + printf -v repo_q '%q' "${REPO_ROOT}" + printf -v cuda_q '%q' "${CUDA_VISIBLE_DEVICES}" + printf -v stablewm_q '%q' "${STABLEWM_HOME}" + + eval_args="" + for arg in "${COMMON_ARGS[@]}"; do + printf -v arg_q '%q' "${arg}" + eval_args+=" ${arg_q}" + done + + printf 'cd %s && source .venv/bin/activate && export CUDA_VISIBLE_DEVICES=%s && export STABLEWM_HOME=%s && torchrun --nnodes=%q --nproc_per_node=%q --node_rank=%q --master_addr=%q --master_port=%q eval.py%s' \ + "${repo_q}" \ + "${cuda_q}" \ + "${stablewm_q}" \ + "${NNODES}" \ + "${NPROC_PER_NODE}" \ + "${node_rank}" \ + "${MASTER_ADDR}" \ + "${MASTER_PORT}" \ + "${eval_args}" +} + +REMOTE_CMD="$(make_command 1)" +LOCAL_CMD="$(make_command 0)" +printf -v REMOTE_CMD_Q '%q' "${REMOTE_CMD}" + +REMOTE_PID="" +LOCAL_PID="" +LOCAL_TAIL_PID="" +REMOTE_TAIL_PID="" +REMOTE_CLEANUP_CMD="" +REMOTE_CLEANUP_CMD_Q="" + +start_log_tail() { + local label="$1" + local log_file="$2" + local label_q log_q + + printf -v label_q '%q' "${label}" + printf -v log_q '%q' "${log_file}" + setsid bash -lc "tail -n +1 -F ${log_q} 2>/dev/null | sed -u 's/^/[${label_q}] /'" & +} + +stop_log_tails() { + local pid + for pid in "${LOCAL_TAIL_PID}" "${REMOTE_TAIL_PID}"; do + if [[ -n "${pid}" ]] && kill -0 "${pid}" 2>/dev/null; then + kill -TERM "-${pid}" 2>/dev/null || kill -TERM "${pid}" 2>/dev/null || true + fi + done +} + +remote_cleanup_command() { + local pattern_q + local patterns=( + "torchrun .*--master_addr=${MASTER_ADDR} .*--master_port=${MASTER_PORT} .*eval.py" + "torchrun .*--master_port=${MASTER_PORT} .*eval.py" + "python.*eval.py .*output.filename=${OUTPUT_FILENAME}" + ) + + printf 'set +e; ' + for pattern in "${patterns[@]}"; do + printf -v pattern_q '%q' "${pattern}" + printf 'pkill -TERM -f %s 2>/dev/null; ' "${pattern_q}" + done + printf 'sleep 2; ' + for pattern in "${patterns[@]}"; do + printf -v pattern_q '%q' "${pattern}" + printf 'pkill -KILL -f %s 2>/dev/null; ' "${pattern_q}" + done + printf 'true' +} + +cleanup() { + local status="$?" + trap - INT TERM EXIT + + if [[ "${status}" -eq 0 ]]; then + return 0 + fi + + echo + echo "Stopping multi-node eval..." + stop_log_tails + + if [[ -n "${LOCAL_PID}" ]] && kill -0 "${LOCAL_PID}" 2>/dev/null; then + kill -TERM "-${LOCAL_PID}" 2>/dev/null || kill -TERM "${LOCAL_PID}" 2>/dev/null || true + fi + + if [[ -n "${REMOTE_PID}" ]] && kill -0 "${REMOTE_PID}" 2>/dev/null; then + kill -TERM "${REMOTE_PID}" 2>/dev/null || true + fi + + ssh "${SSH_OPTS[@]}" "${REMOTE_HOST}" "bash -lc ${REMOTE_CLEANUP_CMD_Q}" >/dev/null 2>&1 || true + + if [[ -n "${LOCAL_PID}" ]] && kill -0 "${LOCAL_PID}" 2>/dev/null; then + sleep 2 + kill -KILL "-${LOCAL_PID}" 2>/dev/null || kill -KILL "${LOCAL_PID}" 2>/dev/null || true + fi + + echo "Cleanup requested. Check logs if any process was already exiting:" + echo " local: ${LOCAL_LOG}" + echo " remote: ${REMOTE_LOG}" + exit "${status}" +} + +trap cleanup INT TERM EXIT + +REMOTE_CLEANUP_CMD="$(remote_cleanup_command)" +printf -v REMOTE_CLEANUP_CMD_Q '%q' "${REMOTE_CLEANUP_CMD}" + +echo "Launching multi-node eval" +echo " master: ${MASTER_ADDR}:${MASTER_PORT}" +echo " remote: ${REMOTE_HOST}" +echo " repo: ${REPO_ROOT}" +echo " stablewm: ${STABLEWM_HOME}" +echo " config: ${CONFIG_NAME}" +echo " policy: ${POLICY}" +echo " output: ${OUTPUT_FILENAME}" +echo " extra: ${EXTRA_ARGS:-}" +echo " tail logs: ${TAIL_LOGS}" +echo " preload wait: ${PRELOAD_WAIT}" +if [[ "${PRELOAD_WAIT}" == "1" ]]; then + echo " preload signal: ${PRELOAD_SIGNAL_FILE}" + echo " start command: touch ${PRELOAD_SIGNAL_FILE}" +fi +echo " local log: ${LOCAL_LOG}" +echo " remote log: ${REMOTE_LOG}" + +if [[ "${DRY_RUN}" == "1" ]]; then + echo + echo "Remote command:" + echo "ssh ${SSH_OPTS[*]} ${REMOTE_HOST} bash -lc ${REMOTE_CMD_Q}" + echo + echo "Local command:" + printf -v LOCAL_CMD_Q '%q' "${LOCAL_CMD}" + echo "bash -lc ${LOCAL_CMD_Q}" + exit 0 +fi + +if [[ "${PRELOAD_WAIT}" == "1" && "${PRELOAD_CLEAR_SIGNAL}" == "1" ]]; then + rm -f "${PRELOAD_SIGNAL_FILE}" +fi + +echo "Starting remote node_rank=1..." +ssh "${SSH_OPTS[@]}" "${REMOTE_HOST}" "bash -lc ${REMOTE_CMD_Q}" >"${REMOTE_LOG}" 2>&1 & +REMOTE_PID="$!" + +if [[ "${TAIL_LOGS}" == "1" ]]; then + start_log_tail "node2" "${REMOTE_LOG}" + REMOTE_TAIL_PID="$!" +fi + +sleep 3 + +echo "Starting local node_rank=0..." +set +e +setsid bash -lc "${LOCAL_CMD}" >"${LOCAL_LOG}" 2>&1 & +LOCAL_PID="$!" + +if [[ "${TAIL_LOGS}" == "1" ]]; then + start_log_tail "node3" "${LOCAL_LOG}" + LOCAL_TAIL_PID="$!" +fi + +wait "${LOCAL_PID}" +LOCAL_STATUS="$?" + +wait "${REMOTE_PID}" +REMOTE_STATUS="$?" +set -e + +stop_log_tails +trap - INT TERM EXIT + +echo "Local status: ${LOCAL_STATUS}" +echo "Remote status: ${REMOTE_STATUS}" +echo "Local log: ${LOCAL_LOG}" +echo "Remote log: ${REMOTE_LOG}" + +if [[ "${LOCAL_STATUS}" -ne 0 || "${REMOTE_STATUS}" -ne 0 ]]; then + echo "Multi-node eval failed. Tail logs:" + echo "===== local tail =====" + tail -80 "${LOCAL_LOG}" || true + echo "===== remote tail =====" + tail -80 "${REMOTE_LOG}" || true + exit 1 +fi + +echo "Multi-node eval complete." diff --git a/scripts/warmup_eval.sh b/scripts/warmup_eval.sh index 83f0cfa..071ba18 100755 --- a/scripts/warmup_eval.sh +++ b/scripts/warmup_eval.sh @@ -44,6 +44,13 @@ mkdir -p "${OUTPUT_DIR}" # ROCR_VISIBLE_DEVICES=2,3 HIP_VISIBLE_DEVICES=0,1 MULTI_GPU=1 MULTI_GPU_DEVICES='[0,1]' MULTI_GPU="${MULTI_GPU:-0}" MULTI_GPU_DEVICES="${MULTI_GPU_DEVICES:-[0,1]}" +MULTI_NODE="${MULTI_NODE:-0}" + +# Multi-node warmup uses the same eval.py entrypoint under torchrun. +# Example: +# torchrun --nnodes=2 --nproc_per_node=2 --node_rank=0 --master_addr= --master_port=29500 \ +# eval.py --config-name=pusht.yaml policy=pusht/lewm multi_node.enabled=true +# This script leaves multi-node launch to the caller. COMMON_ARGS=( "eval.num_eval=${WARMUP_NUM_EVAL}" @@ -57,6 +64,12 @@ if [[ "${MULTI_GPU}" == "1" ]]; then ) fi +if [[ "${MULTI_NODE}" == "1" ]]; then + COMMON_ARGS+=( + "multi_node.enabled=true" + ) +fi + run_warmup() { local config_name="$1" local policy="$2" @@ -80,10 +93,11 @@ echo " HIP_VISIBLE_DEVICES: ${HIP_VISIBLE_DEVICES}" echo " CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES}" echo " WARMUP_NUM_EVAL: ${WARMUP_NUM_EVAL}" echo " INFERENCE_PRECISION: ${INFERENCE_PRECISION}" -echo " MULTI_GPU: ${MULTI_GPU}" -if [[ "${MULTI_GPU}" == "1" ]]; then - echo " MULTI_GPU_DEVICES: ${MULTI_GPU_DEVICES}" -fi + echo " MULTI_GPU: ${MULTI_GPU}" + if [[ "${MULTI_GPU}" == "1" ]]; then + echo " MULTI_GPU_DEVICES: ${MULTI_GPU_DEVICES}" + fi + echo " MULTI_NODE: ${MULTI_NODE}" # Defaults match the checkpoint names used in this repo. If onsite checkpoint # folders differ, override by editing these calls or passing the equivalent