#!/usr/bin/env bash
#
# Driver for tools/tn_contest_runner.py:
#   1. start the Dask cluster (tools/manage_tn_dask_cluster.sh start) and,
#      optionally, wait until the expected number of workers has registered;
#   2. run the `search` sub-command against the Dask scheduler;
#   3. copy the tree directory to every host listed in SYNC_HOSTS;
#   4. run the `contract --mpi` sub-command under the MPI launcher.
#
# Every UPPER_CASE setting below may be overridden from the environment.
set -euo pipefail

ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
cd "$ROOT_DIR"

# --- Problem / search settings (forwarded to tn_contest_runner.py) ----------
CASE="${CASE:-main1}"
OBSERVABLES="${OBSERVABLES:-long_z_string}"
NQUBITS="${NQUBITS:-34}"
NLAYERS="${NLAYERS:-20}"
TORCH_THREADS="${TORCH_THREADS:-48}"
SEARCH_REPEATS="${SEARCH_REPEATS:-2048}"
SEARCH_TIME="${SEARCH_TIME:-300}"
TN_TARGET_SIZE="${TN_TARGET_SIZE:-17179869184}"
TN_TARGET_SLICES="${TN_TARGET_SLICES:-}"
PYTHON_BIN="${PYTHON_BIN:-.venv/bin/python}"
DTYPE="${DTYPE:-complex64}"
TREE_DIR="${TREE_DIR:-trees/contest_tn}"

# --- Dask settings -----------------------------------------------------------
DASK_ADDRESS="${DASK_ADDRESS:-tcp://10.20.1.103:8786}"
DASK_EXPECTED_WORKERS="${DASK_EXPECTED_WORKERS:-}"
DASK_WAIT_FOR_WORKERS="${DASK_WAIT_FOR_WORKERS:-1}"
DASK_WAIT_TIMEOUT="${DASK_WAIT_TIMEOUT:-600}"
TN_DEBUG_TRIALS="${TN_DEBUG_TRIALS:-0}"

# --- MPI launcher settings ---------------------------------------------------
MPIEXEC="${MPIEXEC:-mpirun}"
MPIEXEC_FULL="${MPIEXEC_FULL:-}"
MPI_HOSTS="${MPI_HOSTS:-}"
MPI_HOSTFILE="${MPI_HOSTFILE:-${HOSTFILE:-}}"
MPI_RANKS="${MPI_RANKS:-}"
MPI_PE="${MPI_PE:-$TORCH_THREADS}"
MPI_MAP_BY="${MPI_MAP_BY:-ppr:1:numa:PE=$MPI_PE}"
MPI_BIND_TO="${MPI_BIND_TO:-core}"
MPI_REPORT_BINDINGS="${MPI_REPORT_BINDINGS:-0}"
MPI_EXPORT_ENV="${MPI_EXPORT_ENV:-1}"
TN_CONTRACT_ENV_CHECK="${TN_CONTRACT_ENV_CHECK:-1}"

# --- Tree synchronisation settings -------------------------------------------
SYNC_TREES="${SYNC_TREES:-1}"
SYNC_HOSTS="${SYNC_HOSTS:-${WORKER_HOSTS:-}}"
SSH_BIN="${SSH_BIN:-ssh}"

# Set to 1 once this script has started the cluster; the exit handler only
# stops the cluster when this is 1.
DASK_CLUSTER_MANAGED="${DASK_CLUSTER_MANAGED:-0}"

export TCM_ENABLE="${TCM_ENABLE:-1}"
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-$TORCH_THREADS}"
export MKL_NUM_THREADS="${MKL_NUM_THREADS:-$TORCH_THREADS}"

# NOTE(review): append_mpi_env_args below reads BLIS_NUM_THREADS,
# OMP_PROC_BIND and OMP_PLACES without defaults; presumably this sourced file
# sets them -- confirm.
source "$ROOT_DIR/tools/qibotn_torch_mt_env.sh"

# OBSERVABLES is a whitespace-separated list; split it once so both runner
# invocations receive the same words.
# shellcheck disable=SC2206
observable_args=($OBSERVABLES)

tn_slice_args=(--tn-target-size "$TN_TARGET_SIZE")
if [[ -n "$TN_TARGET_SLICES" ]]; then
  tn_slice_args+=(--tn-target-slices "$TN_TARGET_SLICES")
fi

#######################################
# EXIT handler: stop the Dask cluster if this script started it, then exit
# with the status the script was already exiting with.
# Globals:   DASK_CLUSTER_MANAGED (read)
# Arguments: none
#######################################
cleanup_dask_cluster() {
  local status=$?
  # Disarm the handler so the final `exit` cannot run the cleanup again.
  trap - EXIT
  if [[ "$DASK_CLUSTER_MANAGED" == "1" ]]; then
    set +e
    # Best effort: a failing `stop` must not mask the real exit status.
    tools/manage_tn_dask_cluster.sh stop >/dev/null 2>&1 || true
  fi
  exit "$status"
}
trap cleanup_dask_cluster EXIT
# Signals are turned into a conventional 128+signum exit; the EXIT trap then
# performs the cleanup exactly once and the script never reports success
# after being interrupted.
trap 'exit 129' HUP
trap 'exit 130' INT
trap 'exit 143' TERM

#######################################
# Sum the slot counts of a comma-separated "host[:slots]" list; an entry
# without ":slots" counts as one slot.
# Arguments: $1 - comma-separated host list
# Outputs:   total slot count on stdout
#######################################
sum_host_slots() {
  local hosts="$1"
  local total=0
  local item slots
  local -a host_items
  IFS=',' read -r -a host_items <<< "$hosts"
  for item in "${host_items[@]}"; do
    if [[ "$item" == *:* ]]; then
      slots="${item##*:}"
    else
      slots=1
    fi
    total=$((total + slots))
  done
  echo "$total"
}

#######################################
# Count the non-empty entries of a space-separated host list.
# Arguments: $1 - space-separated host list
# Outputs:   entry count on stdout
#######################################
count_hosts() {
  local hosts="$1"
  local count=0
  local item
  local -a host_items
  IFS=' ' read -r -a host_items <<< "$hosts"
  for item in "${host_items[@]}"; do
    [[ -n "$item" ]] && count=$((count + 1))
  done
  echo "$count"
}

#######################################
# Poll the Dask scheduler until the expected number of workers is connected
# or DASK_WAIT_TIMEOUT seconds have elapsed. A timeout is only reported on
# stdout; the embedded Python program does not treat it as an error.
# Globals:   DASK_WAIT_FOR_WORKERS, DASK_EXPECTED_WORKERS, DASK_WAIT_TIMEOUT,
#            DASK_ADDRESS, PYTHON_BIN, WORKER_HOSTS (optional),
#            NWORKERS (optional) -- all read-only
# Returns:   0 when waiting is disabled or no expectation can be derived;
#            otherwise the exit status of the Python helper
#######################################
wait_for_dask_workers() {
  [[ "$DASK_WAIT_FOR_WORKERS" == "1" ]] || return 0

  local expected="$DASK_EXPECTED_WORKERS"
  # WORKER_HOSTS / NWORKERS are optional; default them so `set -u` does not
  # abort the run when they are absent from the environment.
  local worker_hosts="${WORKER_HOSTS:-}"
  local host_count
  if [[ -z "$expected" && -n "$worker_hosts" ]]; then
    host_count="$(count_hosts "$worker_hosts")"
    # Assumes one worker per host when NWORKERS is unset -- TODO confirm
    # against the default used by tools/manage_tn_dask_cluster.sh.
    expected=$((host_count * ${NWORKERS:-1}))
  fi
  if [[ -z "$expected" || "$expected" -le 0 ]]; then
    return 0
  fi

  echo "Waiting for Dask workers: expected=$expected timeout=${DASK_WAIT_TIMEOUT}s"
  "$PYTHON_BIN" - "$DASK_ADDRESS" "$expected" "$DASK_WAIT_TIMEOUT" <<'PY'
import sys
import time

from distributed import Client

address, expected, timeout = sys.argv[1], int(sys.argv[2]), int(sys.argv[3])
deadline = time.time() + timeout
client = Client(address)
try:
    while True:
        info = client.scheduler_info(n_workers=-1)
        workers = info.get("workers", {})
        count = len(workers)
        if count >= expected:
            print(f"dask_workers_ready count={count} expected={expected}", flush=True)
            break
        if time.time() >= deadline:
            print(
                f"dask_workers_wait_timeout count={count} expected={expected}",
                flush=True,
            )
            break
        time.sleep(2)
finally:
    client.close()
PY
}

#######################################
# Append `-x NAME=value` launcher arguments for the threading-related
# environment variables, unless MPI_EXPORT_ENV is not "1".
# Globals:   mpi_prefix (appended to); MPI_EXPORT_ENV, LD_PRELOAD,
#            BLIS_NUM_THREADS, OMP_NUM_THREADS, MKL_NUM_THREADS,
#            OMP_PROC_BIND, OMP_PLACES (read)
#######################################
append_mpi_env_args() {
  [[ "$MPI_EXPORT_ENV" == "1" ]] || return 0
  mpi_prefix+=(
    -x "LD_PRELOAD=${LD_PRELOAD:-}"
    -x "BLIS_NUM_THREADS=$BLIS_NUM_THREADS"
    -x "OMP_NUM_THREADS=$OMP_NUM_THREADS"
    -x "MKL_NUM_THREADS=$MKL_NUM_THREADS"
    -x "OMP_PROC_BIND=$OMP_PROC_BIND"
    -x "OMP_PLACES=$OMP_PLACES"
  )
}

#######################################
# Build the MPI launcher command line into the global array `mpi_prefix`.
# MPIEXEC_FULL, when set, is word-split and used verbatim as the launcher
# (only the env-export arguments are appended). Otherwise the command is
# assembled from MPIEXEC and the MPI_* settings; the rank count falls back to
# the slot total of MPI_HOSTS and finally to 2.
# Globals:   mpi_prefix (written); MPIEXEC_FULL, MPIEXEC, MPI_RANKS,
#            MPI_HOSTS, MPI_HOSTFILE, MPI_MAP_BY, MPI_BIND_TO,
#            MPI_REPORT_BINDINGS (read)
#######################################
build_mpi_prefix() {
  if [[ -n "$MPIEXEC_FULL" ]]; then
    # shellcheck disable=SC2206
    mpi_prefix=($MPIEXEC_FULL)
    append_mpi_env_args
    return
  fi

  local ranks="$MPI_RANKS"
  if [[ -z "$ranks" && -n "$MPI_HOSTS" ]]; then
    ranks="$(sum_host_slots "$MPI_HOSTS")"
  fi
  if [[ -z "$ranks" ]]; then
    ranks=2
  fi

  mpi_prefix=(
    "$MPIEXEC"
    --map-by "$MPI_MAP_BY"
    --bind-to "$MPI_BIND_TO"
    -np "$ranks"
  )
  if [[ "$MPI_REPORT_BINDINGS" == "1" ]]; then
    mpi_prefix+=(--report-bindings)
  fi
  append_mpi_env_args
  # An explicit host list takes precedence over a hostfile.
  if [[ -n "$MPI_HOSTS" ]]; then
    mpi_prefix+=(-host "$MPI_HOSTS")
  elif [[ -n "$MPI_HOSTFILE" ]]; then
    mpi_prefix+=(-hostfile "$MPI_HOSTFILE")
  fi
}

#######################################
# Tell whether a host name or address refers to the machine running this
# script: "localhost", "127.0.0.1", the output of `hostname` / `hostname -f`,
# or one of the addresses printed by `hostname -I`.
# Arguments: $1 - host name or address
# Returns:   0 if local, non-zero otherwise
#######################################
is_local_host() {
  local host="$1"
  local addr
  [[ "$host" == "localhost" || "$host" == "127.0.0.1" ]] && return 0
  [[ "$host" == "$(hostname)" ]] && return 0
  [[ "$host" == "$(hostname -f 2>/dev/null || true)" ]] && return 0
  # Compare literally: a regex match would let the dots of an IP address
  # match any character, and a grep -q pipeline can fail under pipefail.
  for addr in $(hostname -I 2>/dev/null || true); do
    [[ "$addr" == "$host" ]] && return 0
  done
  return 1
}

#######################################
# Copy the tree directory to every non-local host in SYNC_HOSTS, creating the
# destination directory over SSH first. Uses rsync when it is installed and
# otherwise copies the *.pkl files with scp.
# Globals:   SYNC_TREES, SYNC_HOSTS, TREE_DIR, ROOT_DIR, SSH_BIN (read)
#######################################
sync_trees_to_hosts() {
  [[ "$SYNC_TREES" == "1" ]] || return 0
  [[ -n "$SYNC_HOSTS" ]] || return 0

  local src_dir="$TREE_DIR"
  local dst_dir="$TREE_DIR"
  # A relative tree dir is resolved against the repository root on both ends.
  if [[ "$TREE_DIR" != /* ]]; then
    src_dir="$ROOT_DIR/$TREE_DIR"
    dst_dir="$ROOT_DIR/$TREE_DIR"
  fi

  local host
  # SYNC_HOSTS is a whitespace-separated list; splitting is intentional.
  # shellcheck disable=SC2206
  local -a sync_hosts=($SYNC_HOSTS)
  for host in "${sync_hosts[@]}"; do
    is_local_host "$host" && continue
    echo "Sync tree dir to $host:$dst_dir"
    "$SSH_BIN" "$host" "mkdir -p $(printf '%q' "$dst_dir")"
    if command -v rsync >/dev/null 2>&1; then
      rsync -a "$src_dir/" "$host:$dst_dir/"
    else
      scp -q "$src_dir"/*.pkl "$host:$dst_dir/"
    fi
  done
}

# --- 1. Dask cluster ---------------------------------------------------------
tools/manage_tn_dask_cluster.sh start
DASK_CLUSTER_MANAGED=1
wait_for_dask_workers

# --- 2. Search ---------------------------------------------------------------
echo "Search with dask: $DASK_ADDRESS"
search_args=(
  --case "$CASE"
  --nqubits "$NQUBITS"
  --nlayers "$NLAYERS"
  --observables "${observable_args[@]}"
  --tree-dir "$TREE_DIR"
  --dask-address "$DASK_ADDRESS"
  --torch-threads "$TORCH_THREADS"
  --dtype "$DTYPE"
  --tn-search-repeats "$SEARCH_REPEATS"
  --tn-search-time "$SEARCH_TIME"
  "${tn_slice_args[@]}"
)
if [[ -n "$DASK_EXPECTED_WORKERS" ]]; then
  search_args+=(--dask-expected-workers "$DASK_EXPECTED_WORKERS")
fi
if [[ "$TN_DEBUG_TRIALS" == "1" ]]; then
  search_args+=(--tn-debug-trials)
fi
"$PYTHON_BIN" -u tools/tn_contest_runner.py search "${search_args[@]}"

# --- 3. Distribute trees -----------------------------------------------------
sync_trees_to_hosts

# --- 4. Contract under MPI ---------------------------------------------------
build_mpi_prefix
echo "Contract with MPI: ${mpi_prefix[*]}"

if [[ "$TN_CONTRACT_ENV_CHECK" == "1" ]]; then
  # Each rank prints the threading environment it actually sees and the
  # libblis objects mapped into its process. The continuation lines stay at
  # column 0 so the Python one-liner is passed through unchanged.
  "${mpi_prefix[@]}" "$PYTHON_BIN" -c "from mpi4py import MPI; import os; \
import torch; \
rank = MPI.COMM_WORLD.Get_rank(); \
blis = []; \
[blis.append(line.strip().split()[-1]) for line in open('/proc/self/maps') if 'libblis' in line and line.strip().split()[-1] not in blis]; \
print('tn_contract_env ' + \
f'rank={rank} ' + \
f'LD_PRELOAD={os.environ.get(\"LD_PRELOAD\", \"\")} ' + \
f'BLIS_NUM_THREADS={os.environ.get(\"BLIS_NUM_THREADS\", \"\")} ' + \
f'OMP_NUM_THREADS={os.environ.get(\"OMP_NUM_THREADS\", \"\")} ' + \
f'MKL_NUM_THREADS={os.environ.get(\"MKL_NUM_THREADS\", \"\")} ' + \
f'OMP_PROC_BIND={os.environ.get(\"OMP_PROC_BIND\", \"\")} ' + \
f'OMP_PLACES={os.environ.get(\"OMP_PLACES\", \"\")} ' + \
f'torch_threads={torch.get_num_threads()} ' + \
f'blis={\";\".join(blis) if blis else \"missing\"}', flush=True)"
fi

"${mpi_prefix[@]}" "$PYTHON_BIN" -u tools/tn_contest_runner.py contract \
  --mpi \
  --case "$CASE" \
  --nqubits "$NQUBITS" \
  --nlayers "$NLAYERS" \
  --observables "${observable_args[@]}" \
  --tree-dir "$TREE_DIR" \
  --torch-threads "$TORCH_THREADS" \
  --dtype "$DTYPE" \
  "${tn_slice_args[@]}"