多机
This commit is contained in:
@@ -61,5 +61,17 @@ eval:
|
|||||||
target_quat:
|
target_quat:
|
||||||
value: goal_privileged_block_0_quat
|
value: goal_privileged_block_0_quat
|
||||||
|
|
||||||
|
multi_node:
|
||||||
|
enabled: false
|
||||||
|
backend: gloo
|
||||||
|
rank_env: RANK
|
||||||
|
world_size_env: WORLD_SIZE
|
||||||
|
local_rank_env: LOCAL_RANK
|
||||||
|
|
||||||
|
preload_wait:
|
||||||
|
enabled: false
|
||||||
|
file: /tmp/lewm_preload_start
|
||||||
|
poll_interval: 1.0
|
||||||
|
|
||||||
output:
|
output:
|
||||||
filename: ogb_cube_results.txt
|
filename: ogb_cube_results.txt
|
||||||
|
|||||||
@@ -48,6 +48,18 @@ eval:
|
|||||||
args:
|
args:
|
||||||
goal_state:
|
goal_state:
|
||||||
value: goal_state
|
value: goal_state
|
||||||
|
|
||||||
|
multi_node:
|
||||||
|
enabled: false
|
||||||
|
backend: gloo
|
||||||
|
rank_env: RANK
|
||||||
|
world_size_env: WORLD_SIZE
|
||||||
|
local_rank_env: LOCAL_RANK
|
||||||
|
|
||||||
|
preload_wait:
|
||||||
|
enabled: false
|
||||||
|
file: /tmp/lewm_preload_start
|
||||||
|
poll_interval: 1.0
|
||||||
|
|
||||||
output:
|
output:
|
||||||
filename: pusht_results.txt
|
filename: pusht_results.txt
|
||||||
|
|||||||
@@ -50,5 +50,17 @@ eval:
|
|||||||
target_qpos:
|
target_qpos:
|
||||||
value: goal_qpos
|
value: goal_qpos
|
||||||
|
|
||||||
|
multi_node:
|
||||||
|
enabled: false
|
||||||
|
backend: gloo
|
||||||
|
rank_env: RANK
|
||||||
|
world_size_env: WORLD_SIZE
|
||||||
|
local_rank_env: LOCAL_RANK
|
||||||
|
|
||||||
|
preload_wait:
|
||||||
|
enabled: false
|
||||||
|
file: /tmp/lewm_preload_start
|
||||||
|
poll_interval: 1.0
|
||||||
|
|
||||||
output:
|
output:
|
||||||
filename: dmc_results.txt
|
filename: dmc_results.txt
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
_target_: stable_worldmodel.solver.CEMSolver
|
_target_: stable_worldmodel.solver.CEMSolver
|
||||||
model: ???
|
model: ???
|
||||||
batch_size: 8
|
batch_size: 16
|
||||||
# Original defaults: num_samples=300, n_steps=30, topk=30, batch_size=8.
|
# Original defaults: num_samples=300, n_steps=30, topk=30, batch_size=8.
|
||||||
num_samples: 64
|
num_samples: 64
|
||||||
var_scale: 1.0
|
var_scale: 1.0
|
||||||
|
|||||||
@@ -48,5 +48,17 @@ eval:
|
|||||||
goal_state:
|
goal_state:
|
||||||
value: goal_proprio
|
value: goal_proprio
|
||||||
|
|
||||||
|
multi_node:
|
||||||
|
enabled: false
|
||||||
|
backend: gloo
|
||||||
|
rank_env: RANK
|
||||||
|
world_size_env: WORLD_SIZE
|
||||||
|
local_rank_env: LOCAL_RANK
|
||||||
|
|
||||||
|
preload_wait:
|
||||||
|
enabled: false
|
||||||
|
file: /tmp/lewm_preload_start
|
||||||
|
poll_interval: 1.0
|
||||||
|
|
||||||
output:
|
output:
|
||||||
filename: tworoom_results.txt
|
filename: tworoom_results.txt
|
||||||
|
|||||||
206
eval.py
206
eval.py
@@ -96,6 +96,46 @@ def get_compile_warmup_cfg(cfg):
|
|||||||
return warmup_cfg
|
return warmup_cfg
|
||||||
|
|
||||||
|
|
||||||
|
def get_preload_wait_cfg(cfg):
    """Return the effective ``preload_wait`` settings as a plain dict.

    The built-in defaults (disabled, signal file ``/tmp/lewm_preload_start``,
    poll interval ``1.0``) are overlaid with whatever the ``preload_wait``
    node of ``cfg`` provides. Interpolations in that node are resolved.
    """
    section = cfg.get("preload_wait")
    if section is None:
        overrides = {}
    else:
        overrides = OmegaConf.to_container(section, resolve=True)
    return {
        "enabled": False,
        "file": "/tmp/lewm_preload_start",
        "poll_interval": 1.0,
        **overrides,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def wait_for_preload_signal(cfg, rank=0):
    """Hold the caller until the preload start-signal file exists.

    Does nothing when ``preload_wait.enabled`` is false. Otherwise, when a
    torch.distributed process group is initialized, all ranks meet at a
    barrier, rank 0 announces the signal path and polls for it, and all
    ranks meet at a second barrier before returning.

    Args:
        cfg: Evaluation config; its ``preload_wait`` section is read via
            ``get_preload_wait_cfg``.
        rank: Global rank of the calling process.
    """
    settings = get_preload_wait_cfg(cfg)
    if not settings["enabled"]:
        return

    use_barrier = (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
    )

    def sync_ranks():
        if use_barrier:
            torch.distributed.barrier()

    # First barrier: every rank has finished preloading.
    sync_ranks()

    signal_path = Path(str(settings["file"])).expanduser()
    poll_interval = float(settings["poll_interval"])
    # NOTE(review): only rank 0 watches the file; the other ranks are
    # released by the barrier below. Nesting reconstructed from a
    # whitespace-stripped listing -- confirm against the real file.
    if rank == 0:
        print(
            "Preload ready. Create this file to start evaluation: "
            f"{signal_path}",
            flush=True,
        )
        while not signal_path.exists():
            time.sleep(poll_interval)
        print("Preload start signal received. Starting evaluation.", flush=True)

    # Second barrier: release every rank together once the signal is seen.
    sync_ranks()
|
||||||
|
|
||||||
|
|
||||||
def maybe_compile_inference_target(model, cfg, device):
|
def maybe_compile_inference_target(model, cfg, device):
|
||||||
compile_cfg = get_compile_cfg(cfg)
|
compile_cfg = get_compile_cfg(cfg)
|
||||||
compile_target = "disabled"
|
compile_target = "disabled"
|
||||||
@@ -229,6 +269,55 @@ def get_multi_gpu_cfg(cfg):
|
|||||||
return multi_gpu_cfg
|
return multi_gpu_cfg
|
||||||
|
|
||||||
|
|
||||||
|
def get_multi_node_cfg(cfg):
    """Return the effective ``multi_node`` settings as a plain dict.

    Defaults: disabled, ``gloo`` backend, environment variable names
    ``RANK`` / ``WORLD_SIZE`` / ``LOCAL_RANK``, and ``output_mode`` of
    ``single``. Values in the ``multi_node`` node of ``cfg`` (with
    interpolations resolved) take precedence over these defaults.
    """
    section = cfg.get("multi_node")
    if section is None:
        overrides = {}
    else:
        overrides = OmegaConf.to_container(section, resolve=True)
    return {
        "enabled": False,
        "backend": "gloo",
        "rank_env": "RANK",
        "world_size_env": "WORLD_SIZE",
        "local_rank_env": "LOCAL_RANK",
        "output_mode": "single",
        **overrides,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def get_dist_env(name, default=None):
    """Read an integer from the environment variable ``name``.

    Args:
        name: Environment variable to look up.
        default: Value used when the variable is unset.

    Returns:
        The value converted with ``int``, or ``None`` when the variable is
        unset and ``default`` is ``None``.

    Raises:
        ValueError: If the value cannot be parsed as an integer. The message
            names the offending variable so launcher misconfiguration is
            easy to spot.
    """
    value = os.environ.get(name, default)
    if value is None:
        return None
    try:
        return int(value)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {name} must be an integer, got {value!r}"
        ) from err
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank_context(cfg):
    """Return ``(rank, world_size, local_rank)`` for the current process.

    When ``multi_node.enabled`` is false the single-process context
    ``(0, 1, 0)`` is returned. Otherwise the three values are read from the
    environment variables named by ``multi_node.rank_env``,
    ``multi_node.world_size_env`` and ``multi_node.local_rank_env``; the
    local rank falls back to 0 when its variable is unset.

    Raises:
        ValueError: If the rank or world-size variable is unset, the world
            size is below 1, or the rank is outside ``[0, world_size)``.
            Messages use the configured variable names, so they stay
            accurate when the names are overridden.
    """
    multi_node_cfg = get_multi_node_cfg(cfg)
    if not multi_node_cfg["enabled"]:
        return 0, 1, 0

    rank_env = multi_node_cfg["rank_env"]
    world_size_env = multi_node_cfg["world_size_env"]

    rank = get_dist_env(rank_env)
    world_size = get_dist_env(world_size_env)
    local_rank = get_dist_env(multi_node_cfg["local_rank_env"], 0)

    if rank is None or world_size is None:
        raise ValueError(
            "multi_node.enabled=true requires torchrun env vars "
            f"{rank_env} and {world_size_env}"
        )
    if world_size < 1:
        raise ValueError(f"{world_size_env} must be >= 1")
    if rank < 0 or rank >= world_size:
        raise ValueError(f"{rank_env} must be in [0, {world_size_env})")
    return rank, world_size, local_rank
|
||||||
|
|
||||||
|
|
||||||
|
def all_gather_eval_result(result):
    """Gather ``result`` from every rank and return the list of results.

    The returned list has one slot per process in the default process
    group; it is filled by ``torch.distributed.all_gather_object``.
    """
    gathered = [None] * torch.distributed.get_world_size()
    torch.distributed.all_gather_object(gathered, result)
    return gathered
|
||||||
|
|
||||||
|
|
||||||
def build_process(cfg, dataset):
|
def build_process(cfg, dataset):
|
||||||
process = {}
|
process = {}
|
||||||
for col in cfg.dataset.keys_to_cache:
|
for col in cfg.dataset.keys_to_cache:
|
||||||
@@ -311,6 +400,22 @@ def shard_eval_cases(eval_episodes, eval_start_idx, num_shards):
|
|||||||
return shards
|
return shards
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank_eval_subset(eval_episodes, eval_start_idx, rank, world_size):
    """Return the contiguous slice of evaluation cases owned by ``rank``.

    The ``len(eval_episodes)`` cases are split into ``world_size``
    consecutive shards whose sizes differ by at most one; the first
    ``len(eval_episodes) % world_size`` shards receive the extra case.

    Args:
        eval_episodes: Sliceable sequence of episode identifiers.
        eval_start_idx: Sliceable sequence sliced with the same bounds.
        rank: Index of the shard to return.
        world_size: Total number of shards.

    Returns:
        A tuple ``(episodes, start_indices)`` holding this rank's slices.

    Raises:
        ValueError: If ``world_size < 1`` or ``rank`` is outside
            ``[0, world_size)``.
    """
    if world_size < 1:
        raise ValueError("world_size must be >= 1")
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")

    base, extra = divmod(len(eval_episodes), world_size)
    # Every shard before this one holds `base` cases, and the first `extra`
    # shards hold one more.
    begin = rank * base + min(rank, extra)
    stop = begin + base + (1 if rank < extra else 0)
    return eval_episodes[begin:stop], eval_start_idx[begin:stop]
|
||||||
|
|
||||||
|
|
||||||
def run_eval_subset(
|
def run_eval_subset(
|
||||||
cfg: DictConfig,
|
cfg: DictConfig,
|
||||||
eval_episodes,
|
eval_episodes,
|
||||||
@@ -319,6 +424,7 @@ def run_eval_subset(
|
|||||||
*,
|
*,
|
||||||
device_override: str | None = None,
|
device_override: str | None = None,
|
||||||
enable_profile: bool = True,
|
enable_profile: bool = True,
|
||||||
|
before_evaluate=None,
|
||||||
):
|
):
|
||||||
local_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
|
local_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
|
||||||
local_cfg.eval.num_eval = len(eval_episodes)
|
local_cfg.eval.num_eval = len(eval_episodes)
|
||||||
@@ -376,6 +482,11 @@ def run_eval_subset(
|
|||||||
if str(device).startswith("cuda") and torch.cuda.is_available():
|
if str(device).startswith("cuda") and torch.cuda.is_available():
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
if before_evaluate is not None:
|
||||||
|
before_evaluate()
|
||||||
|
if str(device).startswith("cuda") and torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
def evaluate_subset(episodes, start_indices, *, eval_cfg=local_cfg):
|
def evaluate_subset(episodes, start_indices, *, eval_cfg=local_cfg):
|
||||||
return world.evaluate_from_dataset(
|
return world.evaluate_from_dataset(
|
||||||
dataset,
|
dataset,
|
||||||
@@ -421,6 +532,15 @@ def maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx):
|
|||||||
print("Skipping compile warmup because multi_gpu.enabled=true uses spawned workers.")
|
print("Skipping compile warmup because multi_gpu.enabled=true uses spawned workers.")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if get_multi_node_cfg(cfg)["enabled"]:
|
||||||
|
rank, world_size, local_rank = get_rank_context(cfg)
|
||||||
|
eval_episodes, eval_start_idx = get_rank_eval_subset(
|
||||||
|
eval_episodes, eval_start_idx, rank, world_size
|
||||||
|
)
|
||||||
|
device_override = f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu"
|
||||||
|
else:
|
||||||
|
device_override = None
|
||||||
|
|
||||||
warmup_count = min(int(warmup_cfg["num_eval"]), len(eval_episodes))
|
warmup_count = min(int(warmup_cfg["num_eval"]), len(eval_episodes))
|
||||||
if warmup_count < 1:
|
if warmup_count < 1:
|
||||||
return
|
return
|
||||||
@@ -439,6 +559,7 @@ def maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx):
|
|||||||
eval_episodes[:warmup_count].tolist(),
|
eval_episodes[:warmup_count].tolist(),
|
||||||
eval_start_idx[:warmup_count].tolist(),
|
eval_start_idx[:warmup_count].tolist(),
|
||||||
Path(tmpdir),
|
Path(tmpdir),
|
||||||
|
device_override=device_override,
|
||||||
enable_profile=False,
|
enable_profile=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -551,6 +672,82 @@ def run_multi_gpu_eval(cfg, eval_episodes, eval_start_idx, output_dir: Path):
|
|||||||
"profile_summary_path": None,
|
"profile_summary_path": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def combine_eval_results(ordered_results):
    """Merge per-shard evaluation results into one metrics dict.

    Args:
        ordered_results: Non-empty sequence of result dicts. Each has a
            ``"metrics"`` dict containing ``"episode_successes"`` and,
            optionally, ``"seeds"``. Shards are concatenated in the order
            given.

    Returns:
        A tuple ``(metrics, reference)``. ``metrics`` holds
        ``"success_rate"`` (percentage of successful episodes),
        ``"episode_successes"`` (boolean array over all shards) and
        ``"seeds"`` (concatenated seeds, or ``None`` if any shard has no
        seeds). ``reference`` is the first shard's result.

    Raises:
        ValueError: If ``ordered_results`` is empty or the shards contain
            no episodes at all, since a success rate is undefined then.
    """
    if len(ordered_results) == 0:
        raise ValueError("combine_eval_results requires at least one shard result")

    episode_successes = np.concatenate(
        [
            np.asarray(result["metrics"]["episode_successes"], dtype=np.bool_)
            for result in ordered_results
        ]
    )
    if len(episode_successes) == 0:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError("combine_eval_results received no episode outcomes")

    # Seeds are reported only when every shard provided them.
    shard_seeds = [result["metrics"].get("seeds") for result in ordered_results]
    seeds = None
    if all(seed is not None for seed in shard_seeds):
        seeds = np.concatenate(shard_seeds)

    metrics = {
        "success_rate": float(np.sum(episode_successes)) / len(episode_successes) * 100.0,
        "episode_successes": episode_successes,
        "seeds": seeds,
    }
    return metrics, ordered_results[0]
|
||||||
|
|
||||||
|
|
||||||
|
def run_multi_node_eval(cfg, eval_episodes, eval_start_idx, output_dir: Path):
    """Evaluate this rank's shard and merge the results from all ranks.

    The rank context comes from ``get_rank_context``; this process runs only
    its slice of the evaluation cases through ``run_eval_subset`` with
    multi-node and multi-GPU dispatch switched off in a private config copy.
    Results are then exchanged with ``all_gather_eval_result`` and merged
    with ``combine_eval_results``.

    Returns:
        On rank 0, a dict with the merged ``"metrics"``, the largest
        per-rank ``"evaluation_time"``, the first rank's precision and
        compile settings, and ``None`` for both profile fields. Every other
        rank returns ``None``.

    Raises:
        ValueError: If no evaluation case falls to this rank.
        RuntimeError: If ``torch.distributed`` is unavailable when needed.
    """
    rank, world_size, local_rank = get_rank_context(cfg)
    shard_episodes, shard_start_idx = get_rank_eval_subset(
        eval_episodes, eval_start_idx, rank, world_size
    )
    if len(shard_episodes) == 0:
        raise ValueError("No evaluation episodes assigned to this rank")

    # Private copy so the shard runs as a plain single-process evaluation.
    local_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
    local_cfg.multi_node.enabled = False
    if local_cfg.get("multi_gpu") is not None:
        local_cfg.multi_gpu.enabled = False
    else:
        local_cfg.multi_gpu = OmegaConf.create({"enabled": False})

    def ensure_process_group(unavailable_message):
        # Initialize the default process group once, on first need.
        if not torch.distributed.is_available():
            raise RuntimeError(unavailable_message)
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(
                backend=get_multi_node_cfg(cfg)["backend"]
            )

    def hold_for_preload_signal():
        wait_for_preload_signal(cfg, rank=rank)

    device = f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu"

    # The preload wait uses barriers, so the group must exist before the
    # evaluation when that feature is on.
    if get_preload_wait_cfg(cfg)["enabled"]:
        ensure_process_group("torch.distributed is required for preload_wait")

    result = run_eval_subset(
        local_cfg,
        list(shard_episodes),
        list(shard_start_idx),
        output_dir,
        device_override=device,
        enable_profile=False,
        before_evaluate=hold_for_preload_signal,
    )

    ensure_process_group("torch.distributed is required for multi-node evaluation")

    gathered = all_gather_eval_result(result)
    metrics, reference = combine_eval_results(gathered)
    slowest = max(item["evaluation_time"] for item in gathered)
    combined = {
        "metrics": metrics,
        "evaluation_time": slowest,
        "inference_precision": reference["inference_precision"],
        "compile_target": reference["compile_target"],
        "compile_mode": reference["compile_mode"],
        "profile_dir": None,
        "profile_summary_path": None,
    }
    torch.distributed.barrier()
    return combined if rank == 0 else None
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path="./config/eval", config_name="pusht")
|
@hydra.main(version_base=None, config_path="./config/eval", config_name="pusht")
|
||||||
def run(cfg: DictConfig):
|
def run(cfg: DictConfig):
|
||||||
"""Run evaluation of dinowm vs random policy."""
|
"""Run evaluation of dinowm vs random policy."""
|
||||||
@@ -565,7 +762,14 @@ def run(cfg: DictConfig):
|
|||||||
|
|
||||||
maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx)
|
maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx)
|
||||||
|
|
||||||
if get_multi_gpu_cfg(cfg)["enabled"]:
|
if get_multi_node_cfg(cfg)["enabled"] and get_multi_gpu_cfg(cfg)["enabled"]:
|
||||||
|
raise ValueError("multi_node.enabled and multi_gpu.enabled are mutually exclusive")
|
||||||
|
|
||||||
|
if get_multi_node_cfg(cfg)["enabled"]:
|
||||||
|
eval_result = run_multi_node_eval(cfg, eval_episodes, eval_start_idx, output_dir)
|
||||||
|
if eval_result is None:
|
||||||
|
return
|
||||||
|
elif get_multi_gpu_cfg(cfg)["enabled"]:
|
||||||
if profile_cfg["enabled"]:
|
if profile_cfg["enabled"]:
|
||||||
raise ValueError("Profiling is not supported together with multi_gpu.enabled=true")
|
raise ValueError("Profiling is not supported together with multi_gpu.enabled=true")
|
||||||
eval_result = run_multi_gpu_eval(cfg, eval_episodes, eval_start_idx, output_dir)
|
eval_result = run_multi_gpu_eval(cfg, eval_episodes, eval_start_idx, output_dir)
|
||||||
|
|||||||
255
scripts/launch_multinode_eval.sh
Executable file
255
scripts/launch_multinode_eval.sh
Executable file
@@ -0,0 +1,255 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
# Launch 2-node LeWM evaluation from node-3.
|
||||||
|
#
|
||||||
|
# Defaults match the current cluster layout:
|
||||||
|
# node-3: 10.16.200.9, node_rank=0
|
||||||
|
# node-2: 10.16.200.8, node_rank=1
|
||||||
|
# Each node runs two local torchrun processes for two visible GPUs.
|
||||||
|
|
||||||
|
REPO_ROOT="${REPO_ROOT:-/home/lewm/lewm}"
|
||||||
|
REMOTE_HOST="${REMOTE_HOST:-lewm@10.16.200.8}"
|
||||||
|
MASTER_ADDR="${MASTER_ADDR:-10.16.200.9}"
|
||||||
|
MASTER_PORT="${MASTER_PORT:-29500}"
|
||||||
|
|
||||||
|
NNODES="${NNODES:-2}"
|
||||||
|
NPROC_PER_NODE="${NPROC_PER_NODE:-2}"
|
||||||
|
CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1}"
|
||||||
|
STABLEWM_HOME="${STABLEWM_HOME:-/home/lewm/.stable-wm}"
|
||||||
|
|
||||||
|
CONFIG_NAME="${CONFIG_NAME:-pusht.yaml}"
|
||||||
|
POLICY="${POLICY:-pusht/lewm}"
|
||||||
|
OUTPUT_FILENAME="${OUTPUT_FILENAME:-pusht_multinode_results.txt}"
|
||||||
|
EXTRA_ARGS="${EXTRA_ARGS:-}"
|
||||||
|
DRY_RUN="${DRY_RUN:-0}"
|
||||||
|
TAIL_LOGS="${TAIL_LOGS:-1}"
|
||||||
|
PRELOAD_WAIT="${PRELOAD_WAIT:-0}"
|
||||||
|
PRELOAD_SIGNAL_FILE="${PRELOAD_SIGNAL_FILE:-/tmp/lewm_preload_start}"
|
||||||
|
PRELOAD_CLEAR_SIGNAL="${PRELOAD_CLEAR_SIGNAL:-1}"
|
||||||
|
|
||||||
|
LOG_DIR="${LOG_DIR:-${REPO_ROOT}/logs/multinode}"
|
||||||
|
mkdir -p "${LOG_DIR}"
|
||||||
|
RUN_ID="$(date +%Y%m%d_%H%M%S)"
|
||||||
|
LOCAL_LOG="${LOG_DIR}/${RUN_ID}_node3_rank0.log"
|
||||||
|
REMOTE_LOG="${LOG_DIR}/${RUN_ID}_node2_rank1.log"
|
||||||
|
|
||||||
|
SSH_OPTS=(
|
||||||
|
-F /dev/null
|
||||||
|
-o StrictHostKeyChecking=no
|
||||||
|
-o ServerAliveInterval=30
|
||||||
|
-o ServerAliveCountMax=20
|
||||||
|
)
|
||||||
|
|
||||||
|
COMMON_ARGS=(
|
||||||
|
"--config-name=${CONFIG_NAME}"
|
||||||
|
"policy=${POLICY}"
|
||||||
|
"multi_node.enabled=true"
|
||||||
|
"output.filename=${OUTPUT_FILENAME}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if [[ "${PRELOAD_WAIT}" == "1" ]]; then
|
||||||
|
COMMON_ARGS+=(
|
||||||
|
"preload_wait.enabled=true"
|
||||||
|
"preload_wait.file=${PRELOAD_SIGNAL_FILE}"
|
||||||
|
)
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ -n "${EXTRA_ARGS}" ]]; then
|
||||||
|
# shellcheck disable=SC2206
|
||||||
|
COMMON_ARGS+=(${EXTRA_ARGS})
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Print the shell command line that starts torchrun for one node.
#   $1 - node rank passed to torchrun as --node_rank
# Reads REPO_ROOT, CUDA_VISIBLE_DEVICES, STABLEWM_HOME, NNODES,
# NPROC_PER_NODE, MASTER_ADDR, MASTER_PORT and COMMON_ARGS. Every value is
# %q-quoted so the result can be handed to `bash -lc` unchanged.
make_command() {
  local node_rank="$1"
  # `arg` is declared local so the loop variable does not leak into the
  # script's global scope.
  local repo_q cuda_q stablewm_q arg arg_q eval_args
  printf -v repo_q '%q' "${REPO_ROOT}"
  printf -v cuda_q '%q' "${CUDA_VISIBLE_DEVICES}"
  printf -v stablewm_q '%q' "${STABLEWM_HOME}"

  # Build " arg1 arg2 ..." with each eval.py argument quoted on its own.
  eval_args=""
  for arg in "${COMMON_ARGS[@]}"; do
    printf -v arg_q '%q' "${arg}"
    eval_args+=" ${arg_q}"
  done

  printf 'cd %s && source .venv/bin/activate && export CUDA_VISIBLE_DEVICES=%s && export STABLEWM_HOME=%s && torchrun --nnodes=%q --nproc_per_node=%q --node_rank=%q --master_addr=%q --master_port=%q eval.py%s' \
    "${repo_q}" \
    "${cuda_q}" \
    "${stablewm_q}" \
    "${NNODES}" \
    "${NPROC_PER_NODE}" \
    "${node_rank}" \
    "${MASTER_ADDR}" \
    "${MASTER_PORT}" \
    "${eval_args}"
}
|
||||||
|
|
||||||
|
REMOTE_CMD="$(make_command 1)"
|
||||||
|
LOCAL_CMD="$(make_command 0)"
|
||||||
|
printf -v REMOTE_CMD_Q '%q' "${REMOTE_CMD}"
|
||||||
|
|
||||||
|
REMOTE_PID=""
|
||||||
|
LOCAL_PID=""
|
||||||
|
LOCAL_TAIL_PID=""
|
||||||
|
REMOTE_TAIL_PID=""
|
||||||
|
REMOTE_CLEANUP_CMD=""
|
||||||
|
REMOTE_CLEANUP_CMD_Q=""
|
||||||
|
|
||||||
|
# Follow a log file in the background, prefixing each line with "[label] ".
#   $1 - label shown in the prefix
#   $2 - log file to follow (tail -F keeps retrying until it exists)
# The pipeline runs under setsid as the function's only background job, so
# the caller can read its PID from $! right after the call.
start_log_tail() {
  local label_q log_q
  printf -v label_q '%q' "$1"
  printf -v log_q '%q' "$2"
  setsid bash -lc "tail -n +1 -F ${log_q} 2>/dev/null | sed -u 's/^/[${label_q}] /'" &
}
|
||||||
|
|
||||||
|
# Terminate the background log followers recorded in LOCAL_TAIL_PID and
# REMOTE_TAIL_PID. Unset or already-exited PIDs are skipped.
stop_log_tails() {
  local tail_pid
  for tail_pid in "${LOCAL_TAIL_PID}" "${REMOTE_TAIL_PID}"; do
    [[ -n "${tail_pid}" ]] || continue
    kill -0 "${tail_pid}" 2>/dev/null || continue
    # Signal the whole process group first (tail and sed); fall back to the
    # single PID if that fails.
    kill -TERM "-${tail_pid}" 2>/dev/null || kill -TERM "${tail_pid}" 2>/dev/null || true
  done
}
|
||||||
|
|
||||||
|
# Print a shell snippet that stops this run's processes on the remote node.
# The snippet sends SIGTERM to every process whose command line matches one
# of the patterns below, sleeps 2 seconds, sends SIGKILL to whatever still
# matches, and always ends successfully.
remote_cleanup_command() {
  # `pattern` is declared local so the loop variable does not leak into the
  # script's global scope.
  local pattern pattern_q
  local patterns=(
    "torchrun .*--master_addr=${MASTER_ADDR} .*--master_port=${MASTER_PORT} .*eval.py"
    "torchrun .*--master_port=${MASTER_PORT} .*eval.py"
    "python.*eval.py .*output.filename=${OUTPUT_FILENAME}"
  )

  printf 'set +e; '
  for pattern in "${patterns[@]}"; do
    printf -v pattern_q '%q' "${pattern}"
    printf 'pkill -TERM -f %s 2>/dev/null; ' "${pattern_q}"
  done
  printf 'sleep 2; '
  for pattern in "${patterns[@]}"; do
    printf -v pattern_q '%q' "${pattern}"
    printf 'pkill -KILL -f %s 2>/dev/null; ' "${pattern_q}"
  done
  printf 'true'
}
|
||||||
|
|
||||||
|
# Trap handler: tear down local and remote work after a failure.
# Returns immediately when the triggering exit status is 0. Otherwise it
# stops the log followers, terminates the local job and the ssh client,
# runs the remote cleanup snippet, force-kills the local job if it is still
# alive, and exits with the original status.
cleanup() {
  local status="$?"
  # Disarm the traps so the handler cannot re-enter itself.
  trap - INT TERM EXIT

  if [[ "${status}" -eq 0 ]]; then
    return 0
  fi

  echo
  echo "Stopping multi-node eval..."
  stop_log_tails

  if [[ -n "${LOCAL_PID}" ]] && kill -0 "${LOCAL_PID}" 2>/dev/null; then
    kill -TERM "-${LOCAL_PID}" 2>/dev/null || kill -TERM "${LOCAL_PID}" 2>/dev/null || true
  fi

  if [[ -n "${REMOTE_PID}" ]] && kill -0 "${REMOTE_PID}" 2>/dev/null; then
    kill -TERM "${REMOTE_PID}" 2>/dev/null || true
  fi

  # Only touch the remote host if this run actually started a remote job
  # and the cleanup snippet has been built. Otherwise the pkill patterns
  # could hit processes that belong to another run.
  if [[ -n "${REMOTE_PID}" && -n "${REMOTE_CLEANUP_CMD_Q}" ]]; then
    ssh "${SSH_OPTS[@]}" "${REMOTE_HOST}" "bash -lc ${REMOTE_CLEANUP_CMD_Q}" >/dev/null 2>&1 || true
  fi

  if [[ -n "${LOCAL_PID}" ]] && kill -0 "${LOCAL_PID}" 2>/dev/null; then
    sleep 2
    kill -KILL "-${LOCAL_PID}" 2>/dev/null || kill -KILL "${LOCAL_PID}" 2>/dev/null || true
  fi

  echo "Cleanup requested. Check logs if any process was already exiting:"
  echo "  local:  ${LOCAL_LOG}"
  echo "  remote: ${REMOTE_LOG}"
  exit "${status}"
}
|
||||||
|
|
||||||
|
trap cleanup INT TERM EXIT
|
||||||
|
|
||||||
|
REMOTE_CLEANUP_CMD="$(remote_cleanup_command)"
|
||||||
|
printf -v REMOTE_CLEANUP_CMD_Q '%q' "${REMOTE_CLEANUP_CMD}"
|
||||||
|
|
||||||
|
echo "Launching multi-node eval"
|
||||||
|
echo " master: ${MASTER_ADDR}:${MASTER_PORT}"
|
||||||
|
echo " remote: ${REMOTE_HOST}"
|
||||||
|
echo " repo: ${REPO_ROOT}"
|
||||||
|
echo " stablewm: ${STABLEWM_HOME}"
|
||||||
|
echo " config: ${CONFIG_NAME}"
|
||||||
|
echo " policy: ${POLICY}"
|
||||||
|
echo " output: ${OUTPUT_FILENAME}"
|
||||||
|
echo " extra: ${EXTRA_ARGS:-<none>}"
|
||||||
|
echo " tail logs: ${TAIL_LOGS}"
|
||||||
|
echo " preload wait: ${PRELOAD_WAIT}"
|
||||||
|
if [[ "${PRELOAD_WAIT}" == "1" ]]; then
|
||||||
|
echo " preload signal: ${PRELOAD_SIGNAL_FILE}"
|
||||||
|
echo " start command: touch ${PRELOAD_SIGNAL_FILE}"
|
||||||
|
fi
|
||||||
|
echo " local log: ${LOCAL_LOG}"
|
||||||
|
echo " remote log: ${REMOTE_LOG}"
|
||||||
|
|
||||||
|
if [[ "${DRY_RUN}" == "1" ]]; then
|
||||||
|
echo
|
||||||
|
echo "Remote command:"
|
||||||
|
echo "ssh ${SSH_OPTS[*]} ${REMOTE_HOST} bash -lc ${REMOTE_CMD_Q}"
|
||||||
|
echo
|
||||||
|
echo "Local command:"
|
||||||
|
printf -v LOCAL_CMD_Q '%q' "${LOCAL_CMD}"
|
||||||
|
echo "bash -lc ${LOCAL_CMD_Q}"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "${PRELOAD_WAIT}" == "1" && "${PRELOAD_CLEAR_SIGNAL}" == "1" ]]; then
|
||||||
|
rm -f "${PRELOAD_SIGNAL_FILE}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Starting remote node_rank=1..."
|
||||||
|
ssh "${SSH_OPTS[@]}" "${REMOTE_HOST}" "bash -lc ${REMOTE_CMD_Q}" >"${REMOTE_LOG}" 2>&1 &
|
||||||
|
REMOTE_PID="$!"
|
||||||
|
|
||||||
|
if [[ "${TAIL_LOGS}" == "1" ]]; then
|
||||||
|
start_log_tail "node2" "${REMOTE_LOG}"
|
||||||
|
REMOTE_TAIL_PID="$!"
|
||||||
|
fi
|
||||||
|
|
||||||
|
sleep 3
|
||||||
|
|
||||||
|
echo "Starting local node_rank=0..."
|
||||||
|
set +e
|
||||||
|
setsid bash -lc "${LOCAL_CMD}" >"${LOCAL_LOG}" 2>&1 &
|
||||||
|
LOCAL_PID="$!"
|
||||||
|
|
||||||
|
if [[ "${TAIL_LOGS}" == "1" ]]; then
|
||||||
|
start_log_tail "node3" "${LOCAL_LOG}"
|
||||||
|
LOCAL_TAIL_PID="$!"
|
||||||
|
fi
|
||||||
|
|
||||||
|
wait "${LOCAL_PID}"
|
||||||
|
LOCAL_STATUS="$?"
|
||||||
|
|
||||||
|
wait "${REMOTE_PID}"
|
||||||
|
REMOTE_STATUS="$?"
|
||||||
|
set -e
|
||||||
|
|
||||||
|
stop_log_tails
|
||||||
|
trap - INT TERM EXIT
|
||||||
|
|
||||||
|
echo "Local status: ${LOCAL_STATUS}"
|
||||||
|
echo "Remote status: ${REMOTE_STATUS}"
|
||||||
|
echo "Local log: ${LOCAL_LOG}"
|
||||||
|
echo "Remote log: ${REMOTE_LOG}"
|
||||||
|
|
||||||
|
if [[ "${LOCAL_STATUS}" -ne 0 || "${REMOTE_STATUS}" -ne 0 ]]; then
|
||||||
|
echo "Multi-node eval failed. Tail logs:"
|
||||||
|
echo "===== local tail ====="
|
||||||
|
tail -80 "${LOCAL_LOG}" || true
|
||||||
|
echo "===== remote tail ====="
|
||||||
|
tail -80 "${REMOTE_LOG}" || true
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Multi-node eval complete."
|
||||||
@@ -44,6 +44,13 @@ mkdir -p "${OUTPUT_DIR}"
|
|||||||
# ROCR_VISIBLE_DEVICES=2,3 HIP_VISIBLE_DEVICES=0,1 MULTI_GPU=1 MULTI_GPU_DEVICES='[0,1]'
|
# ROCR_VISIBLE_DEVICES=2,3 HIP_VISIBLE_DEVICES=0,1 MULTI_GPU=1 MULTI_GPU_DEVICES='[0,1]'
|
||||||
MULTI_GPU="${MULTI_GPU:-0}"
|
MULTI_GPU="${MULTI_GPU:-0}"
|
||||||
MULTI_GPU_DEVICES="${MULTI_GPU_DEVICES:-[0,1]}"
|
MULTI_GPU_DEVICES="${MULTI_GPU_DEVICES:-[0,1]}"
|
||||||
|
MULTI_NODE="${MULTI_NODE:-0}"
|
||||||
|
|
||||||
|
# Multi-node warmup uses the same eval.py entrypoint under torchrun.
|
||||||
|
# Example:
|
||||||
|
# torchrun --nnodes=2 --nproc_per_node=2 --node_rank=0 --master_addr=<ip> --master_port=29500 \
|
||||||
|
# eval.py --config-name=pusht.yaml policy=pusht/lewm multi_node.enabled=true
|
||||||
|
# This script leaves multi-node launch to the caller.
|
||||||
|
|
||||||
COMMON_ARGS=(
|
COMMON_ARGS=(
|
||||||
"eval.num_eval=${WARMUP_NUM_EVAL}"
|
"eval.num_eval=${WARMUP_NUM_EVAL}"
|
||||||
@@ -57,6 +64,12 @@ if [[ "${MULTI_GPU}" == "1" ]]; then
|
|||||||
)
|
)
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if [[ "${MULTI_NODE}" == "1" ]]; then
|
||||||
|
COMMON_ARGS+=(
|
||||||
|
"multi_node.enabled=true"
|
||||||
|
)
|
||||||
|
fi
|
||||||
|
|
||||||
run_warmup() {
|
run_warmup() {
|
||||||
local config_name="$1"
|
local config_name="$1"
|
||||||
local policy="$2"
|
local policy="$2"
|
||||||
@@ -80,10 +93,11 @@ echo " HIP_VISIBLE_DEVICES: ${HIP_VISIBLE_DEVICES}"
|
|||||||
echo " CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES}"
|
echo " CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES}"
|
||||||
echo " WARMUP_NUM_EVAL: ${WARMUP_NUM_EVAL}"
|
echo " WARMUP_NUM_EVAL: ${WARMUP_NUM_EVAL}"
|
||||||
echo " INFERENCE_PRECISION: ${INFERENCE_PRECISION}"
|
echo " INFERENCE_PRECISION: ${INFERENCE_PRECISION}"
|
||||||
echo " MULTI_GPU: ${MULTI_GPU}"
|
echo " MULTI_GPU: ${MULTI_GPU}"
|
||||||
if [[ "${MULTI_GPU}" == "1" ]]; then
|
if [[ "${MULTI_GPU}" == "1" ]]; then
|
||||||
echo " MULTI_GPU_DEVICES: ${MULTI_GPU_DEVICES}"
|
echo " MULTI_GPU_DEVICES: ${MULTI_GPU_DEVICES}"
|
||||||
fi
|
fi
|
||||||
|
echo " MULTI_NODE: ${MULTI_NODE}"
|
||||||
|
|
||||||
# Defaults match the checkpoint names used in this repo. If onsite checkpoint
|
# Defaults match the checkpoint names used in this repo. If onsite checkpoint
|
||||||
# folders differ, override by editing these calls or passing the equivalent
|
# folders differ, override by editing these calls or passing the equivalent
|
||||||
|
|||||||
Reference in New Issue
Block a user