Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
261 lines
7.4 KiB
Bash
Executable File
261 lines
7.4 KiB
Bash
Executable File
#!/usr/bin/env bash
# Driver script: run the Dask-based search, then the MPI-based contraction,
# from the repository root (the parent of this script's directory).

# Strict mode: abort on errors, on unset variables and on pipeline failures.
set -o errexit -o nounset -o pipefail

# Resolve the repository root and work from there so that the relative
# paths used below (tools/..., .venv/..., trees/...) resolve.
ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)
cd "$ROOT_DIR"

# ---------------------------------------------------------------------------
# Tunables. Every value can be overridden from the environment; the value
# after ":=" is used only when the variable is unset or empty.
# ---------------------------------------------------------------------------

# Problem definition, passed to tools/tn_contest_runner.py.
: "${CASE:=main1}"
: "${OBSERVABLES:=long_z_string}"
: "${NQUBITS:=34}"
: "${NLAYERS:=20}"
: "${TORCH_THREADS:=48}"
: "${SEARCH_REPEATS:=2048}"
: "${SEARCH_TIME:=300}"
: "${TN_TARGET_SIZE:=17179869184}" # 2^34
: "${TN_TARGET_SLICES:=}"          # empty: --tn-target-slices is not passed

# Runner configuration.
: "${PYTHON_BIN:=.venv/bin/python}"
: "${DTYPE:=complex64}"
: "${TREE_DIR:=trees/contest_tn}"

# Dask scheduler used by the search step.
: "${DASK_ADDRESS:=tcp://10.20.1.103:8786}"
: "${DASK_EXPECTED_WORKERS:=}"
: "${DASK_WAIT_FOR_WORKERS:=1}"
: "${DASK_WAIT_TIMEOUT:=600}"
: "${TN_DEBUG_TRIALS:=0}"

# MPI launcher used by the contract step. MPIEXEC_FULL, when non-empty,
# replaces the generated launcher command line.
: "${MPIEXEC:=mpirun}"
: "${MPIEXEC_FULL:=}"
: "${MPI_HOSTS:=}"
: "${MPI_HOSTFILE:=${HOSTFILE:-}}"
: "${MPI_RANKS:=}"
: "${MPI_PE:=$TORCH_THREADS}"
: "${MPI_MAP_BY:=ppr:1:numa:PE=$MPI_PE}"
: "${MPI_BIND_TO:=core}"
: "${MPI_REPORT_BINDINGS:=0}"
: "${MPI_EXPORT_ENV:=1}"
: "${TN_CONTRACT_ENV_CHECK:=1}"

# Copying the tree directory to other hosts before contraction.
: "${SYNC_TREES:=1}"
: "${SYNC_HOSTS:=${WORKER_HOSTS:-}}"
: "${SSH_BIN:=ssh}"

# Set to 1 once this script has started the Dask cluster; the exit handler
# only stops the cluster when this is 1.
: "${DASK_CLUSTER_MANAGED:=0}"

# Environment handed to child processes. Values already present in the
# environment win; the thread counts otherwise follow TORCH_THREADS.
: "${TCM_ENABLE:=1}"
: "${OMP_NUM_THREADS:=$TORCH_THREADS}"
: "${MKL_NUM_THREADS:=$TORCH_THREADS}"
export TCM_ENABLE OMP_NUM_THREADS MKL_NUM_THREADS

# Project environment file, sourced into this shell.
# shellcheck source=/dev/null
. "$ROOT_DIR/tools/qibotn_torch_mt_env.sh"

# Slicing options shared by the search and contract invocations.
# --tn-target-slices is appended only when TN_TARGET_SLICES is non-empty.
tn_slice_args=(--tn-target-size "$TN_TARGET_SIZE")
[[ -z "$TN_TARGET_SLICES" ]] \
  || tn_slice_args+=(--tn-target-slices "$TN_TARGET_SLICES")

#######################################
# Exit/signal handler: stop the Dask cluster if this script started it,
# then exit with the status that was current when the handler fired.
# Globals:   DASK_CLUSTER_MANAGED (read)
# Arguments: none
# Returns:   never returns; exits with the captured status
#######################################
cleanup_dask_cluster() {
  local status=$?
  # Disarm first: when entered from INT/TERM/HUP, the `exit` below would
  # otherwise fire the EXIT trap too and stop the cluster a second time.
  trap - EXIT INT TERM HUP
  if [[ "$DASK_CLUSTER_MANAGED" == "1" ]]; then
    set +e
    # Best effort: a failing stop must not replace the original exit status.
    tools/manage_tn_dask_cluster.sh stop >/dev/null 2>&1 || true
  fi
  exit "$status"
}

trap cleanup_dask_cluster EXIT INT TERM HUP

#######################################
# Sum the slot counts of a comma-separated host list.
# Each entry is "host" (counts as 1 slot) or "host:N" (counts as N slots,
# N taken from the text after the last ':'). Empty entries are ignored.
# Arguments: $1 - host list, e.g. "n1:2,n2"
# Outputs:   the total slot count on stdout
# Returns:   0 on success; 1 if a slot count is not a non-negative integer
#######################################
sum_host_slots() {
  local hosts="$1"
  local total=0
  local item slots
  local -a host_items=()

  IFS=',' read -r -a host_items <<< "$hosts"
  # The ${arr[@]+...} form keeps an empty list safe under `set -u` on
  # older bash versions.
  for item in ${host_items[@]+"${host_items[@]}"}; do
    # "a,,b" or a stray comma must not be counted as a host.
    [[ -n "$item" ]] || continue
    if [[ "$item" == *:* ]]; then
      slots="${item##*:}"
    else
      slots=1
    fi
    # Validate before doing arithmetic: arithmetic expansion would evaluate
    # arbitrary text as an expression.
    if [[ ! "$slots" =~ ^[0-9]+$ ]]; then
      printf 'sum_host_slots: invalid slot count in host entry: %s\n' \
        "$item" >&2
      return 1
    fi
    # 10# forces base 10 so "08" is not rejected as invalid octal.
    total=$((total + 10#$slots))
  done
  echo "$total"
}

#######################################
# Count the entries of a whitespace-separated host list.
# Splits on spaces, tabs and newlines, matching how the same kind of list
# is word-split when it is iterated in sync_trees_to_hosts.
# Arguments: $1 - host list, e.g. "n1 n2 n3"
# Outputs:   the number of hosts on stdout
#######################################
count_hosts() {
  local hosts="$1"
  local -a host_items=()

  # -d '' makes read consume the whole string rather than only its first
  # line; it then reports end-of-input with a non-zero status, which is
  # expected and must not trip `set -e`.
  IFS=$' \t\n' read -r -d '' -a host_items <<< "$hosts" || true
  echo "${#host_items[@]}"
}

#######################################
# Block until the Dask scheduler reports the expected number of workers,
# or until the timeout passes. A timeout is reported but is not an error.
# The expected count is DASK_EXPECTED_WORKERS when set; otherwise it is
# derived as (number of WORKER_HOSTS entries) * NWORKERS. When no positive
# count can be determined the function returns immediately.
# Globals:   DASK_WAIT_FOR_WORKERS, DASK_EXPECTED_WORKERS, DASK_ADDRESS,
#            DASK_WAIT_TIMEOUT, PYTHON_BIN (read);
#            WORKER_HOSTS, NWORKERS (read, optional)
# Outputs:   progress messages on stdout
# Returns:   0, or the Python helper's status if it fails
#######################################
wait_for_dask_workers() {
  [[ "$DASK_WAIT_FOR_WORKERS" == "1" ]] || return 0

  local expected="$DASK_EXPECTED_WORKERS"
  # WORKER_HOSTS and NWORKERS are not assigned by this script, so they must
  # be read with defaults to stay safe under `set -u`.
  local worker_hosts="${WORKER_HOSTS:-}"
  if [[ -z "$expected" && -n "$worker_hosts" ]]; then
    # NOTE(review): NWORKERS is presumably the per-host worker count used by
    # the cluster manager; defaulting to 1 waits for at least one worker per
    # host. Confirm against tools/manage_tn_dask_cluster.sh.
    expected=$(( $(count_hosts "$worker_hosts") * ${NWORKERS:-1} ))
  fi
  if [[ -z "$expected" || "$expected" -le 0 ]]; then
    return 0
  fi

  echo "Waiting for Dask workers: expected=$expected timeout=${DASK_WAIT_TIMEOUT}s"
  "$PYTHON_BIN" - "$DASK_ADDRESS" "$expected" "$DASK_WAIT_TIMEOUT" <<'PY'
import sys
import time
from distributed import Client

address, expected, timeout = sys.argv[1], int(sys.argv[2]), int(sys.argv[3])
deadline = time.time() + timeout
client = Client(address)
try:
    # Poll the scheduler every 2 seconds until enough workers have joined
    # or the deadline passes.
    while True:
        info = client.scheduler_info(n_workers=-1)
        workers = info.get("workers", {})
        count = len(workers)
        if count >= expected:
            print(f"dask_workers_ready count={count} expected={expected}", flush=True)
            break
        if time.time() >= deadline:
            print(
                f"dask_workers_wait_timeout count={count} expected={expected}",
                flush=True,
            )
            break
        time.sleep(2)
finally:
    client.close()
PY
}

#######################################
# Append "-x NAME=value" launcher arguments to mpi_prefix for the
# threading-related environment variables. Does nothing unless
# MPI_EXPORT_ENV is "1".
# LD_PRELOAD is always forwarded (possibly empty). The other variables are
# forwarded only when they are set, so a missing one neither aborts the
# script under `set -u` nor reaches the ranks as an empty string.
# Globals:   MPI_EXPORT_ENV (read); mpi_prefix (appended to);
#            LD_PRELOAD, BLIS_NUM_THREADS, OMP_NUM_THREADS, MKL_NUM_THREADS,
#            OMP_PROC_BIND, OMP_PLACES (read, optional)
#######################################
append_mpi_env_args() {
  [[ "$MPI_EXPORT_ENV" == "1" ]] || return 0

  local name
  mpi_prefix+=(-x "LD_PRELOAD=${LD_PRELOAD:-}")
  for name in BLIS_NUM_THREADS OMP_NUM_THREADS MKL_NUM_THREADS \
    OMP_PROC_BIND OMP_PLACES; do
    # ${!name+x} is non-empty only when the variable named by $name is set.
    [[ -n "${!name+x}" ]] || continue
    mpi_prefix+=(-x "$name=${!name}")
  done
}

#######################################
# Build the MPI launcher command line into the global array mpi_prefix.
# When MPIEXEC_FULL is non-empty it is split on whitespace and used as the
# launcher; otherwise the command is assembled from the MPI_* settings.
# The rank count is MPI_RANKS, else the slot total of MPI_HOSTS, else 2.
# Globals:   MPIEXEC_FULL, MPIEXEC, MPI_RANKS, MPI_HOSTS, MPI_HOSTFILE,
#            MPI_MAP_BY, MPI_BIND_TO, MPI_REPORT_BINDINGS (read);
#            mpi_prefix (written)
#######################################
build_mpi_prefix() {
  if [[ -n "$MPIEXEC_FULL" ]]; then
    mpi_prefix=()
    # Split on whitespace without pathname expansion, so a word such as
    # "*" in the launcher string is not replaced by file names. With -d ''
    # read always reports end-of-input, hence the `|| true`.
    IFS=$' \t\n' read -r -d '' -a mpi_prefix <<< "$MPIEXEC_FULL" || true
    append_mpi_env_args
    return
  fi

  local ranks="$MPI_RANKS"
  if [[ -z "$ranks" && -n "$MPI_HOSTS" ]]; then
    ranks="$(sum_host_slots "$MPI_HOSTS")"
  fi

  mpi_prefix=(
    "$MPIEXEC"
    --map-by "$MPI_MAP_BY"
    --bind-to "$MPI_BIND_TO"
    -np "${ranks:-2}"
  )
  if [[ "$MPI_REPORT_BINDINGS" == "1" ]]; then
    mpi_prefix+=(--report-bindings)
  fi
  append_mpi_env_args

  # An explicit host list takes precedence over a hostfile.
  if [[ -n "$MPI_HOSTS" ]]; then
    mpi_prefix+=(-host "$MPI_HOSTS")
  elif [[ -n "$MPI_HOSTFILE" ]]; then
    mpi_prefix+=(-hostfile "$MPI_HOSTFILE")
  fi
}

#######################################
# Tell whether a host name or address refers to this machine.
# Matches "localhost", "127.0.0.1", the output of `hostname` and
# `hostname -f`, and any address printed by `hostname -I`. All comparisons
# are exact string matches.
# Arguments: $1 - host name or address
# Returns:   0 if the host is local, 1 otherwise (including an empty host)
#######################################
is_local_host() {
  local host="$1"
  local addr
  local -a local_addrs=()

  [[ -n "$host" ]] || return 1
  [[ "$host" == "localhost" || "$host" == "127.0.0.1" ]] && return 0
  [[ "$host" == "$(hostname)" ]] && return 0
  [[ "$host" == "$(hostname -f 2>/dev/null || true)" ]] && return 0

  # Compare addresses in the shell rather than through `grep -qx`: the host
  # would be treated as a regular expression ("." matching any character),
  # and under `pipefail` an early-exiting `grep -q` can make the pipeline
  # fail even though it matched.
  IFS=$' \t\n' read -r -d '' -a local_addrs \
    <<< "$(hostname -I 2>/dev/null || true)" || true
  for addr in ${local_addrs[@]+"${local_addrs[@]}"}; do
    [[ "$addr" == "$host" ]] && return 0
  done
  return 1
}

#######################################
# Copy the tree directory to every non-local host in SYNC_HOSTS.
# Does nothing unless SYNC_TREES is "1" and SYNC_HOSTS is non-empty.
# A relative TREE_DIR is resolved against ROOT_DIR on both sides. The
# directory is created remotely, then copied with rsync when available,
# otherwise with scp (which copies only the *.pkl files).
# Globals:   SYNC_TREES, SYNC_HOSTS, TREE_DIR, ROOT_DIR, SSH_BIN (read)
# Outputs:   one progress line per synced host on stdout
# Returns:   0 on success; non-zero if a remote command or copy fails
#######################################
sync_trees_to_hosts() {
  [[ "$SYNC_TREES" == "1" ]] || return 0
  [[ -n "$SYNC_HOSTS" ]] || return 0

  local src_dir="$TREE_DIR"
  local dst_dir="$TREE_DIR"
  if [[ "$TREE_DIR" != /* ]]; then
    src_dir="$ROOT_DIR/$TREE_DIR"
    dst_dir="$ROOT_DIR/$TREE_DIR"
  fi

  # Split the host list on whitespace without pathname expansion, and keep
  # the loop variable local so it does not leak into the caller.
  local host
  local -a hosts=()
  IFS=$' \t\n' read -r -d '' -a hosts <<< "$SYNC_HOSTS" || true

  for host in ${hosts[@]+"${hosts[@]}"}; do
    is_local_host "$host" && continue
    echo "Sync tree dir to $host:$dst_dir"
    # %q quotes the path for the remote shell.
    "$SSH_BIN" "$host" "mkdir -p $(printf '%q' "$dst_dir")"
    # Use the same remote shell for the copy as for the mkdir above, so a
    # customised SSH_BIN applies to the whole transfer.
    if command -v rsync >/dev/null 2>&1; then
      rsync -a -e "$SSH_BIN" "$src_dir/" "$host:$dst_dir/"
    else
      scp -q -S "$SSH_BIN" "$src_dir"/*.pkl "$host:$dst_dir/"
    fi
  done
}

# ---------------------------------------------------------------------------
# Main flow: start the Dask cluster, run the search, sync the resulting
# trees to the other hosts, then run the contraction under MPI.
# ---------------------------------------------------------------------------

tools/manage_tn_dask_cluster.sh start
# From here on the exit handler stops the cluster again.
DASK_CLUSTER_MANAGED=1
wait_for_dask_workers

# OBSERVABLES may hold several whitespace-separated values. Split it once
# into an array so each value is passed as its own argument without being
# subject to pathname expansion.
observable_args=()
IFS=$' \t\n' read -r -d '' -a observable_args <<< "$OBSERVABLES" || true

echo "Search with dask: $DASK_ADDRESS"
search_args=(
  --case "$CASE"
  --nqubits "$NQUBITS"
  --nlayers "$NLAYERS"
  --observables ${observable_args[@]+"${observable_args[@]}"}
  --tree-dir "$TREE_DIR"
  --dask-address "$DASK_ADDRESS"
  --torch-threads "$TORCH_THREADS"
  --dtype "$DTYPE"
  --tn-search-repeats "$SEARCH_REPEATS"
  --tn-search-time "$SEARCH_TIME"
  "${tn_slice_args[@]}"
)
if [[ -n "$DASK_EXPECTED_WORKERS" ]]; then
  search_args+=(--dask-expected-workers "$DASK_EXPECTED_WORKERS")
fi
if [[ "$TN_DEBUG_TRIALS" == "1" ]]; then
  search_args+=(--tn-debug-trials)
fi
"$PYTHON_BIN" -u tools/tn_contest_runner.py search "${search_args[@]}"

sync_trees_to_hosts

build_mpi_prefix
echo "Contract with MPI: ${mpi_prefix[*]}"

if [[ "$TN_CONTRACT_ENV_CHECK" == "1" ]]; then
  # Each rank prints one "tn_contract_env ..." line with its threading
  # environment and the libblis libraries mapped into the process. The
  # program is passed with -c rather than on stdin because it runs under
  # the MPI launcher.
  IFS= read -r -d '' env_check_py <<'PY' || true
from mpi4py import MPI
import os
import torch

rank = MPI.COMM_WORLD.Get_rank()

# Distinct paths of mapped libraries whose maps line mentions libblis,
# in first-seen order.
blis = []
with open('/proc/self/maps') as maps:
    for line in maps:
        if 'libblis' not in line:
            continue
        path = line.strip().split()[-1]
        if path not in blis:
            blis.append(path)

env_names = (
    'LD_PRELOAD',
    'BLIS_NUM_THREADS',
    'OMP_NUM_THREADS',
    'MKL_NUM_THREADS',
    'OMP_PROC_BIND',
    'OMP_PLACES',
)
fields = [f'rank={rank}']
fields += [f'{name}={os.environ.get(name, "")}' for name in env_names]
fields.append(f'torch_threads={torch.get_num_threads()}')
fields.append(f'blis={";".join(blis) if blis else "missing"}')
print('tn_contract_env ' + ' '.join(fields), flush=True)
PY
  "${mpi_prefix[@]}" "$PYTHON_BIN" -c "$env_check_py"
fi

"${mpi_prefix[@]}" "$PYTHON_BIN" -u tools/tn_contest_runner.py contract \
  --mpi \
  --case "$CASE" \
  --nqubits "$NQUBITS" \
  --nlayers "$NLAYERS" \
  --observables ${observable_args[@]+"${observable_args[@]}"} \
  --tree-dir "$TREE_DIR" \
  --torch-threads "$TORCH_THREADS" \
  --dtype "$DTYPE" \
  "${tn_slice_args[@]}"