完善mps的vidal机制,多节点并行;补充tn搜索时dask集群搜索的方式
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
This commit is contained in:
148
tools/run_cpu_single_cases.sh
Executable file
148
tools/run_cpu_single_cases.sh
Executable file
@@ -0,0 +1,148 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# Single-node CPU scale probes for expectation benchmarks.
|
||||
#
|
||||
# Intended for one 96-core / ~500 GiB RAM node. The default "probe" mode runs
|
||||
# moderate MPS and TN cases first. Larger modes are available after checking
|
||||
# runtime and memory from the probe output.
|
||||
|
||||
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
||||
cd "$ROOT_DIR"
|
||||
|
||||
PYTHON_BIN="${PYTHON_BIN:-.venv/bin/python}"
|
||||
PYTHON_FLAGS="${PYTHON_FLAGS:--u}"
|
||||
MPIEXEC="${MPIEXEC:-mpiexec}"
|
||||
TIME_BIN="${TIME_BIN:-/usr/bin/time}"
|
||||
|
||||
MPS_RANKS="${MPS_RANKS:-8}"
|
||||
MPS_THREADS="${MPS_THREADS:-12}"
|
||||
TN_RANKS="${TN_RANKS:-8}"
|
||||
TN_THREADS="${TN_THREADS:-12}"
|
||||
|
||||
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-1}"
|
||||
export MKL_NUM_THREADS="${MKL_NUM_THREADS:-1}"
|
||||
|
||||
estimate_mps_memory() {
|
||||
local nqubits="$1"
|
||||
local bond="$2"
|
||||
"$PYTHON_BIN" - "$nqubits" "$bond" "$MPS_RANKS" <<'PY'
|
||||
import sys
|
||||
n = int(sys.argv[1])
|
||||
chi = int(sys.argv[2])
|
||||
ranks = int(sys.argv[3])
|
||||
resident = n * 2 * chi * chi * 16
|
||||
per_rank = resident / ranks
|
||||
print(
|
||||
"MPS rough resident memory: "
|
||||
f"total={resident / 1024**3:.1f} GiB "
|
||||
f"per_rank={per_rank / 1024**3:.1f} GiB "
|
||||
"(temporary eig/SVD workspaces are additional)"
|
||||
)
|
||||
PY
|
||||
}
|
||||
|
||||
run_timed() {
|
||||
echo
|
||||
echo "--------------------------------------------------------------------------------"
|
||||
echo "$*"
|
||||
echo "--------------------------------------------------------------------------------"
|
||||
"$TIME_BIN" -v "$@"
|
||||
}
|
||||
|
||||
run_mps_case() {
|
||||
local label="$1"
|
||||
local nqubits="$2"
|
||||
local nlayers="$3"
|
||||
local bond="$4"
|
||||
shift 4
|
||||
echo
|
||||
echo "================================================================================"
|
||||
echo "$label"
|
||||
echo "================================================================================"
|
||||
echo "PYTHON_BIN=$PYTHON_BIN MPIEXEC=$MPIEXEC"
|
||||
echo "MPS_RANKS=$MPS_RANKS MPS_THREADS=$MPS_THREADS"
|
||||
echo "OMP_NUM_THREADS=$OMP_NUM_THREADS MKL_NUM_THREADS=$MKL_NUM_THREADS"
|
||||
estimate_mps_memory "$nqubits" "$bond"
|
||||
run_timed "$MPIEXEC" -n "$MPS_RANKS" "$PYTHON_BIN" $PYTHON_FLAGS benchmark_cpu_expectation.py \
|
||||
--mpi --mps \
|
||||
--nqubits "$nqubits" \
|
||||
--nlayers "$nlayers" \
|
||||
--bond "$bond" \
|
||||
--torch-threads "$MPS_THREADS" \
|
||||
"$@"
|
||||
}
|
||||
|
||||
run_tn_case() {
|
||||
local label="$1"
|
||||
local nqubits="$2"
|
||||
local nlayers="$3"
|
||||
shift 3
|
||||
echo
|
||||
echo "================================================================================"
|
||||
echo "$label"
|
||||
echo "================================================================================"
|
||||
echo "PYTHON_BIN=$PYTHON_BIN MPIEXEC=$MPIEXEC"
|
||||
echo "TN_RANKS=$TN_RANKS TN_THREADS=$TN_THREADS"
|
||||
echo "OMP_NUM_THREADS=$OMP_NUM_THREADS MKL_NUM_THREADS=$MKL_NUM_THREADS"
|
||||
echo "TN memory is contraction-tree dependent; increase --tn-target-slices if RSS is high."
|
||||
run_timed "$MPIEXEC" -n "$TN_RANKS" "$PYTHON_BIN" $PYTHON_FLAGS benchmark_cpu_expectation.py \
|
||||
--mpi \
|
||||
--nqubits "$nqubits" \
|
||||
--nlayers "$nlayers" \
|
||||
--torch-threads "$TN_THREADS" \
|
||||
"$@"
|
||||
}
|
||||
|
||||
case "${1:-help}" in
|
||||
probe)
|
||||
run_mps_case "MPS probe: n=40 layers=30 bond=2048" 40 30 2048 \
|
||||
--circuits brickwall_cnot \
|
||||
--observables ring_xz
|
||||
|
||||
run_tn_case "TN probe: n=28 layers=12 target_slices=8" 28 12 \
|
||||
--circuits brickwall_cnot \
|
||||
--observables ring_xz \
|
||||
--tn-target-slices 8
|
||||
;;
|
||||
|
||||
mps-medium)
|
||||
run_mps_case "MPS medium: n=56 layers=40 bond=3072" 56 40 3072 \
|
||||
--circuits brickwall_cnot reversed_cnot shifted_cz rxx_rzz \
|
||||
--observables ring_xz open_zz mixed_local range2_xx
|
||||
;;
|
||||
|
||||
mps-long)
|
||||
run_mps_case "MPS long: n=64 layers=48 bond=4096" 64 48 4096 \
|
||||
--circuits brickwall_cnot reversed_cnot shifted_cz rxx_rzz \
|
||||
--observables ring_xz open_zz mixed_local range2_xx
|
||||
;;
|
||||
|
||||
tn-medium)
|
||||
run_tn_case "TN medium: n=32 layers=16 target_slices=16" 32 16 \
|
||||
--circuits brickwall_cnot shifted_cz rxx_rzz \
|
||||
--observables ring_xz open_zz range2_xx \
|
||||
--tn-target-slices 16
|
||||
;;
|
||||
|
||||
tn-long)
|
||||
run_tn_case "TN long: n=36 layers=20 target_slices=32" 36 20 \
|
||||
--circuits brickwall_cnot shifted_cz rxx_rzz \
|
||||
--observables ring_xz open_zz range2_xx \
|
||||
--tn-target-slices 32
|
||||
;;
|
||||
|
||||
help|*)
|
||||
cat >&2 <<'EOF'
|
||||
Usage: tools/run_cpu_single_cases.sh [probe|mps-medium|mps-long|tn-medium|tn-long]
|
||||
|
||||
Common overrides:
|
||||
PYTHON_BIN=.venv/bin/python
|
||||
MPIEXEC=mpiexec
|
||||
MPS_RANKS=8 MPS_THREADS=12
|
||||
TN_RANKS=8 TN_THREADS=12
|
||||
OMP_NUM_THREADS=1 MKL_NUM_THREADS=1
|
||||
EOF
|
||||
exit 2
|
||||
;;
|
||||
esac
|
||||
Reference in New Issue
Block a user