决赛现场脚本
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
This commit is contained in:
@@ -16,3 +16,4 @@ Files here are intentionally secondary:
|
||||
- `benchmark_tn_mpi.py`, `benchmark_search.py`, `benchmark_slice.py`, `benchmark_contract_sliced.py`, `check_tree.py`: old TN path-search/slicing experiments.
|
||||
- `qibojit_reference_expectation.py`: state-vector reference helper.
|
||||
- `validate_vidal_mpi_correctness.py`: focused Vidal MPI correctness helper.
|
||||
- `mpi_torch_thread_probe.py`: MPI + torch OpenMP affinity and threading probe.
|
||||
|
||||
157
tools/benchmark_qredtea_svd_controls.py
Normal file
157
tools/benchmark_qredtea_svd_controls.py
Normal file
@@ -0,0 +1,157 @@
|
||||
#!/usr/bin/env python
|
||||
"""Benchmark qredtea/qtealeaves SVD control modes.
|
||||
|
||||
This isolates the tensor split used by MPS updates: a rank-2 tensor is split
|
||||
with singular values contracted either left or right, then reconstructed to
|
||||
measure numerical error and timing.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import statistics
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
import qmatchatea
|
||||
from qredtea.torchapi import QteaTorchTensor
|
||||
|
||||
|
||||
def _dtype(name: str):
|
||||
return {
|
||||
"complex64": torch.complex64,
|
||||
"complex128": torch.complex128,
|
||||
"float64": torch.float64,
|
||||
"float32": torch.float32,
|
||||
}[name]
|
||||
|
||||
|
||||
def _random_matrix(shape, dtype, seed):
|
||||
gen = torch.Generator(device="cpu")
|
||||
gen.manual_seed(seed)
|
||||
if dtype.is_complex:
|
||||
real_dtype = torch.float32 if dtype == torch.complex64 else torch.float64
|
||||
real = torch.randn(shape, dtype=real_dtype, generator=gen)
|
||||
imag = torch.randn(shape, dtype=real_dtype, generator=gen)
|
||||
return torch.complex(real, imag).to(dtype)
|
||||
return torch.randn(shape, dtype=dtype, generator=gen)
|
||||
|
||||
|
||||
def _sync():
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def run_one(matrix, ctrl, max_bond, contract_singvals, repeats):
|
||||
conv = qmatchatea.QCConvergenceParameters(
|
||||
max_bond_dimension=max_bond,
|
||||
cut_ratio=0.0,
|
||||
svd_ctrl=ctrl,
|
||||
)
|
||||
qtensor = QteaTorchTensor.from_elem_array(matrix, dtype=matrix.dtype, device="cpu")
|
||||
|
||||
times = []
|
||||
rel_error = None
|
||||
kept = None
|
||||
status = "ok"
|
||||
error = ""
|
||||
|
||||
for i in range(repeats):
|
||||
gc.collect()
|
||||
_sync()
|
||||
t0 = time.perf_counter()
|
||||
try:
|
||||
left, right, singvals, _ = qtensor.split_svd(
|
||||
[0],
|
||||
[1],
|
||||
contract_singvals=contract_singvals,
|
||||
conv_params=conv,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001 - benchmark should keep going
|
||||
status = "error"
|
||||
error = repr(exc)
|
||||
break
|
||||
_sync()
|
||||
times.append(time.perf_counter() - t0)
|
||||
|
||||
if i == repeats - 1:
|
||||
left_matrix = left.elem.reshape(matrix.shape[0], -1)
|
||||
right_matrix = right.elem.reshape(-1, matrix.shape[1])
|
||||
recon = left_matrix @ right_matrix
|
||||
rel_error = (
|
||||
torch.linalg.vector_norm(matrix - recon)
|
||||
/ torch.linalg.vector_norm(matrix)
|
||||
).item()
|
||||
kept = int(singvals.numel())
|
||||
|
||||
return {
|
||||
"ctrl": ctrl,
|
||||
"contract_singvals": contract_singvals,
|
||||
"status": status,
|
||||
"median_ms": float("nan") if not times else statistics.median(times) * 1000,
|
||||
"min_ms": float("nan") if not times else min(times) * 1000,
|
||||
"rel_error": rel_error,
|
||||
"kept": kept,
|
||||
"error": error,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--shapes", nargs="+", default=("256x1024", "1024x256", "512x512"))
|
||||
parser.add_argument("--max-bond", type=int, default=128)
|
||||
parser.add_argument("--dtype", choices=("complex64", "complex128", "float32", "float64"), default="complex128")
|
||||
parser.add_argument("--threads", type=int, default=8)
|
||||
parser.add_argument("--repeats", type=int, default=3)
|
||||
parser.add_argument(
|
||||
"--controls",
|
||||
nargs="+",
|
||||
default=("A", "D", "V", "R", "E", "E!", "X", "X!"),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
torch.set_num_threads(args.threads)
|
||||
dtype = _dtype(args.dtype)
|
||||
|
||||
print(
|
||||
"svd_benchmark "
|
||||
f"dtype={args.dtype} threads={torch.get_num_threads()} "
|
||||
f"max_bond={args.max_bond} repeats={args.repeats}",
|
||||
flush=True,
|
||||
)
|
||||
print(
|
||||
"columns shape contract ctrl status median_ms min_ms kept rel_error error",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
for shape_text in args.shapes:
|
||||
m_text, n_text = shape_text.lower().split("x", 1)
|
||||
shape = (int(m_text), int(n_text))
|
||||
matrix = _random_matrix(shape, dtype, seed=sum(shape))
|
||||
for contract_singvals in ("L", "R"):
|
||||
for ctrl in args.controls:
|
||||
result = run_one(
|
||||
matrix,
|
||||
ctrl=ctrl,
|
||||
max_bond=args.max_bond,
|
||||
contract_singvals=contract_singvals,
|
||||
repeats=args.repeats,
|
||||
)
|
||||
print(
|
||||
f"row shape={shape_text} "
|
||||
f"contract={contract_singvals} "
|
||||
f"ctrl={ctrl} "
|
||||
f"status={result['status']} "
|
||||
f"median_ms={result['median_ms']:.3f} "
|
||||
f"min_ms={result['min_ms']:.3f} "
|
||||
f"kept={result['kept']} "
|
||||
f"rel_error={result['rel_error']} "
|
||||
f"error={result['error']}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -17,10 +17,10 @@ set -euo pipefail
|
||||
# WORKER_HOSTS="10.20.1.103 10.20.6.101"
|
||||
# NWORKERS=48
|
||||
# NTHREADS=1
|
||||
# ROOT_DIR=/home/yx/qibotn
|
||||
# ROOT_DIR=/home/qibo/qibotn
|
||||
# PYTHON_BIN=.venv/bin/python
|
||||
|
||||
ROOT_DIR="${ROOT_DIR:-/home/yx/qibotn}"
|
||||
ROOT_DIR="${ROOT_DIR:-/home/qibo/qibotn}"
|
||||
PYTHON_BIN="${PYTHON_BIN:-.venv/bin/python}"
|
||||
SCHEDULER_HOST="${SCHEDULER_HOST:-10.20.1.103}"
|
||||
SCHEDULER_PORT="${SCHEDULER_PORT:-8786}"
|
||||
|
||||
182
tools/mpi_torch_thread_probe.py
Normal file
182
tools/mpi_torch_thread_probe.py
Normal file
@@ -0,0 +1,182 @@
|
||||
#!/usr/bin/env python
|
||||
"""Probe MPI rank placement and whether torch CPU ops use multiple threads.
|
||||
|
||||
Run this under mpirun/mpiexec to check:
|
||||
|
||||
* which CPUs each rank is allowed to run on,
|
||||
* whether torch sees the requested intra-op thread count, and
|
||||
* whether a large CPU tensor op actually consumes more CPU time than wall time.
|
||||
|
||||
The script is intentionally small and self-contained so it can be used to debug
|
||||
MPI launcher affinity and torch OpenMP behavior independently from the TN code
|
||||
path.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import socket
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from mpi4py import MPI
|
||||
|
||||
|
||||
def _dtype_from_name(name):
|
||||
import torch
|
||||
|
||||
mapping = {
|
||||
"float32": torch.float32,
|
||||
"float64": torch.float64,
|
||||
"complex64": torch.complex64,
|
||||
"complex128": torch.complex128,
|
||||
}
|
||||
return mapping[name]
|
||||
|
||||
|
||||
def _make_tensor(shape, dtype):
|
||||
import torch
|
||||
|
||||
if dtype in (torch.complex64, torch.complex128):
|
||||
base = torch.float32 if dtype == torch.complex64 else torch.float64
|
||||
return torch.complex(
|
||||
torch.randn(shape, dtype=base),
|
||||
torch.randn(shape, dtype=base),
|
||||
)
|
||||
return torch.randn(shape, dtype=dtype)
|
||||
|
||||
|
||||
def _bench(label, fn, iters, warmup=2):
|
||||
for _ in range(warmup):
|
||||
fn()
|
||||
|
||||
start_wall = time.perf_counter()
|
||||
start_cpu = time.process_time()
|
||||
checksum = 0.0
|
||||
for _ in range(iters):
|
||||
value = fn()
|
||||
checksum += float(value)
|
||||
wall = time.perf_counter() - start_wall
|
||||
cpu = time.process_time() - start_cpu
|
||||
ratio = cpu / wall if wall > 0 else float("inf")
|
||||
print(
|
||||
f"{label} wall={wall:.3f}s cpu={cpu:.3f}s cpu_over_wall={ratio:.2f} "
|
||||
f"checksum={checksum:.6e}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
|
||||
def _visible_numa_nodes():
|
||||
nodes = []
|
||||
for path in sorted(Path("/sys/devices/system/node").glob("node[0-9]*")):
|
||||
cpulist = path / "cpulist"
|
||||
if cpulist.exists():
|
||||
nodes.append(f"{path.name}:{cpulist.read_text(encoding='utf-8').strip()}")
|
||||
return ",".join(nodes) if nodes else "unknown"
|
||||
|
||||
|
||||
def _dtype_nbytes(name):
|
||||
return {
|
||||
"float32": 4,
|
||||
"float64": 8,
|
||||
"complex64": 8,
|
||||
"complex128": 16,
|
||||
}[name]
|
||||
|
||||
|
||||
def _format_gib(nbytes):
|
||||
return f"{nbytes / (1024 ** 3):.2f}GiB"
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--threads", type=int, default=48)
|
||||
parser.add_argument("--n", type=int, default=4096)
|
||||
parser.add_argument("--iters", type=int, default=4)
|
||||
parser.add_argument("--dtype", choices=("float32", "float64", "complex64", "complex128"), default="float32")
|
||||
parser.add_argument("--op", choices=("matmul", "tensordot", "both"), default="both")
|
||||
parser.add_argument(
|
||||
"--affinity-only",
|
||||
action="store_true",
|
||||
help="Print MPI/torch placement diagnostics without allocating tensors.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
os.environ.setdefault("OMP_NUM_THREADS", str(args.threads))
|
||||
os.environ.setdefault("MKL_NUM_THREADS", str(args.threads))
|
||||
os.environ.setdefault("OMP_PROC_BIND", "close")
|
||||
os.environ.setdefault("OMP_PLACES", "cores")
|
||||
|
||||
import torch
|
||||
|
||||
comm = MPI.COMM_WORLD
|
||||
rank = comm.Get_rank()
|
||||
size = comm.Get_size()
|
||||
|
||||
torch.set_num_threads(args.threads)
|
||||
try:
|
||||
torch.set_num_interop_threads(1)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
dtype = _dtype_from_name(args.dtype)
|
||||
affinity = sorted(os.sched_getaffinity(0))
|
||||
allowed_list = ""
|
||||
try:
|
||||
with open("/proc/self/status", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.startswith("Cpus_allowed_list:"):
|
||||
allowed_list = line.split(":", 1)[1].strip()
|
||||
break
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
print(
|
||||
f"rank={rank}/{size} host={socket.gethostname()} pid={os.getpid()} "
|
||||
f"affinity_len={len(affinity)} allowed={allowed_list} "
|
||||
f"torch_threads={torch.get_num_threads()} "
|
||||
f"torch_interop={torch.get_num_interop_threads()} "
|
||||
f"OMP_NUM_THREADS={os.environ.get('OMP_NUM_THREADS')} "
|
||||
f"MKL_NUM_THREADS={os.environ.get('MKL_NUM_THREADS')} "
|
||||
f"OMP_PROC_BIND={os.environ.get('OMP_PROC_BIND')} "
|
||||
f"OMP_PLACES={os.environ.get('OMP_PLACES')} "
|
||||
f"visible_numa={_visible_numa_nodes()}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
if rank == 0:
|
||||
print(torch.__config__.parallel_info(), flush=True)
|
||||
input_bytes = args.n * args.n * _dtype_nbytes(args.dtype)
|
||||
min_live_bytes = 3 * input_bytes
|
||||
print(
|
||||
f"matrix_n={args.n} dtype={args.dtype} "
|
||||
f"one_matrix={_format_gib(input_bytes)} "
|
||||
f"approx_min_live_per_rank={_format_gib(min_live_bytes)} "
|
||||
f"approx_min_live_all_ranks={_format_gib(min_live_bytes * size)}",
|
||||
flush=True,
|
||||
)
|
||||
comm.Barrier()
|
||||
if args.affinity_only:
|
||||
return
|
||||
|
||||
a = _make_tensor((args.n, args.n), dtype)
|
||||
b = _make_tensor((args.n, args.n), dtype)
|
||||
|
||||
def run_matmul():
|
||||
value = (a @ b).sum()
|
||||
return value.real.item() if value.is_complex() else value.item()
|
||||
|
||||
def run_tensordot():
|
||||
value = torch.tensordot(a, b, dims=1)
|
||||
value = value.sum()
|
||||
return value.real.item() if value.is_complex() else value.item()
|
||||
|
||||
if args.op in ("matmul", "both"):
|
||||
_bench("matmul", run_matmul, args.iters)
|
||||
if args.op in ("tensordot", "both"):
|
||||
_bench("tensordot", run_tensordot, args.iters)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
22
tools/qibotn_torch_mt_env.sh
Normal file
22
tools/qibotn_torch_mt_env.sh
Normal file
@@ -0,0 +1,22 @@
|
||||
#!/usr/bin/env bash
|
||||
# Shared runtime setup for CPU torch TN/MPS runs.
|
||||
#
|
||||
# This makes AOCL BLIS use the multithreaded library when available, which is
|
||||
# required for complex64 tensordot/cgemm to actually use all cores on this host.
|
||||
|
||||
QIBOTN_BLIS_MT="${QIBOTN_BLIS_MT:-/home/aocc/aocl/5.2.0/aocc/lib_LP64/libblis-mt.so.5}"
|
||||
|
||||
export BLIS_NUM_THREADS="${BLIS_NUM_THREADS:-${OMP_NUM_THREADS:-1}}"
|
||||
|
||||
if [[ -f "$QIBOTN_BLIS_MT" ]]; then
|
||||
case ":${LD_PRELOAD:-}:" in
|
||||
*":$QIBOTN_BLIS_MT:"*)
|
||||
;;
|
||||
*)
|
||||
export LD_PRELOAD="${LD_PRELOAD:+$LD_PRELOAD:}$QIBOTN_BLIS_MT"
|
||||
;;
|
||||
esac
|
||||
fi
|
||||
|
||||
export OMP_PROC_BIND="${OMP_PROC_BIND:-close}"
|
||||
export OMP_PLACES="${OMP_PLACES:-cores}"
|
||||
@@ -21,6 +21,7 @@ TN_THREADS="${TN_THREADS:-8}"
|
||||
|
||||
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-1}"
|
||||
export MKL_NUM_THREADS="${MKL_NUM_THREADS:-1}"
|
||||
source "$ROOT_DIR/tools/qibotn_torch_mt_env.sh"
|
||||
|
||||
run_mpi() {
|
||||
local ranks="$1"
|
||||
|
||||
@@ -22,6 +22,7 @@ TN_THREADS="${TN_THREADS:-12}"
|
||||
|
||||
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-1}"
|
||||
export MKL_NUM_THREADS="${MKL_NUM_THREADS:-1}"
|
||||
source "$ROOT_DIR/tools/qibotn_torch_mt_env.sh"
|
||||
|
||||
estimate_mps_memory() {
|
||||
local nqubits="$1"
|
||||
|
||||
@@ -11,25 +11,165 @@ NLAYERS="${NLAYERS:-20}"
|
||||
TORCH_THREADS="${TORCH_THREADS:-48}"
|
||||
SEARCH_REPEATS="${SEARCH_REPEATS:-2048}"
|
||||
SEARCH_TIME="${SEARCH_TIME:-300}"
|
||||
TN_TARGET_SIZE="${TN_TARGET_SIZE:-8589934592}"
|
||||
TN_TARGET_SIZE="${TN_TARGET_SIZE:-17179869184}"
|
||||
TN_TARGET_SLICES="${TN_TARGET_SLICES:-}"
|
||||
|
||||
PYTHON_BIN="${PYTHON_BIN:-.venv/bin/python}"
|
||||
DTYPE="${DTYPE:-complex64}"
|
||||
TREE_DIR="${TREE_DIR:-trees/contest_tn}"
|
||||
DASK_ADDRESS="${DASK_ADDRESS:-tcp://10.20.1.103:8786}"
|
||||
MPIEXEC_FULL="${MPIEXEC_FULL:-mpirun -np 4 -hostfile /home/yx/qibotn/hostfile -perhost 2}"
|
||||
DASK_EXPECTED_WORKERS="${DASK_EXPECTED_WORKERS:-}"
|
||||
DASK_WAIT_FOR_WORKERS="${DASK_WAIT_FOR_WORKERS:-1}"
|
||||
DASK_WAIT_TIMEOUT="${DASK_WAIT_TIMEOUT:-600}"
|
||||
TN_DEBUG_TRIALS="${TN_DEBUG_TRIALS:-0}"
|
||||
MPIEXEC="${MPIEXEC:-mpirun}"
|
||||
MPIEXEC_FULL="${MPIEXEC_FULL:-}"
|
||||
MPI_HOSTS="${MPI_HOSTS:-}"
|
||||
MPI_HOSTFILE="${MPI_HOSTFILE:-${HOSTFILE:-}}"
|
||||
MPI_RANKS="${MPI_RANKS:-}"
|
||||
MPI_PE="${MPI_PE:-$TORCH_THREADS}"
|
||||
MPI_MAP_BY="${MPI_MAP_BY:-ppr:1:numa:PE=$MPI_PE}"
|
||||
MPI_BIND_TO="${MPI_BIND_TO:-core}"
|
||||
MPI_REPORT_BINDINGS="${MPI_REPORT_BINDINGS:-0}"
|
||||
MPI_EXPORT_ENV="${MPI_EXPORT_ENV:-1}"
|
||||
TN_CONTRACT_ENV_CHECK="${TN_CONTRACT_ENV_CHECK:-1}"
|
||||
SYNC_TREES="${SYNC_TREES:-1}"
|
||||
SYNC_HOSTS="${SYNC_HOSTS:-${WORKER_HOSTS:-}}"
|
||||
SSH_BIN="${SSH_BIN:-ssh}"
|
||||
DASK_CLUSTER_MANAGED="${DASK_CLUSTER_MANAGED:-0}"
|
||||
|
||||
export TCM_ENABLE="${TCM_ENABLE:-1}"
|
||||
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-$TORCH_THREADS}"
|
||||
export MKL_NUM_THREADS="${MKL_NUM_THREADS:-$TORCH_THREADS}"
|
||||
source "$ROOT_DIR/tools/qibotn_torch_mt_env.sh"
|
||||
|
||||
tn_slice_args=(--tn-target-size "$TN_TARGET_SIZE")
|
||||
if [[ -n "$TN_TARGET_SLICES" ]]; then
|
||||
tn_slice_args+=(--tn-target-slices "$TN_TARGET_SLICES")
|
||||
fi
|
||||
|
||||
cleanup_dask_cluster() {
|
||||
local status=$?
|
||||
if [[ "$DASK_CLUSTER_MANAGED" == "1" ]]; then
|
||||
set +e
|
||||
tools/manage_tn_dask_cluster.sh stop >/dev/null 2>&1 || true
|
||||
fi
|
||||
exit "$status"
|
||||
}
|
||||
|
||||
trap cleanup_dask_cluster EXIT INT TERM HUP
|
||||
|
||||
sum_host_slots() {
|
||||
local hosts="$1"
|
||||
local total=0
|
||||
local item slots
|
||||
IFS=',' read -r -a host_items <<< "$hosts"
|
||||
for item in "${host_items[@]}"; do
|
||||
if [[ "$item" == *:* ]]; then
|
||||
slots="${item##*:}"
|
||||
else
|
||||
slots=1
|
||||
fi
|
||||
total=$((total + slots))
|
||||
done
|
||||
echo "$total"
|
||||
}
|
||||
|
||||
count_hosts() {
|
||||
local hosts="$1"
|
||||
local count=0
|
||||
local item
|
||||
IFS=' ' read -r -a host_items <<< "$hosts"
|
||||
for item in "${host_items[@]}"; do
|
||||
[[ -n "$item" ]] && count=$((count + 1))
|
||||
done
|
||||
echo "$count"
|
||||
}
|
||||
|
||||
wait_for_dask_workers() {
|
||||
[[ "$DASK_WAIT_FOR_WORKERS" == "1" ]] || return 0
|
||||
local expected="$DASK_EXPECTED_WORKERS"
|
||||
if [[ -z "$expected" && -n "$WORKER_HOSTS" ]]; then
|
||||
expected=$(( $(count_hosts "$WORKER_HOSTS") * NWORKERS ))
|
||||
fi
|
||||
if [[ -z "$expected" || "$expected" -le 0 ]]; then
|
||||
return 0
|
||||
fi
|
||||
|
||||
echo "Waiting for Dask workers: expected=$expected timeout=${DASK_WAIT_TIMEOUT}s"
|
||||
"$PYTHON_BIN" - "$DASK_ADDRESS" "$expected" "$DASK_WAIT_TIMEOUT" <<'PY'
|
||||
import sys
|
||||
import time
|
||||
from distributed import Client
|
||||
|
||||
address, expected, timeout = sys.argv[1], int(sys.argv[2]), int(sys.argv[3])
|
||||
deadline = time.time() + timeout
|
||||
client = Client(address)
|
||||
try:
|
||||
while True:
|
||||
info = client.scheduler_info(n_workers=-1)
|
||||
workers = info.get("workers", {})
|
||||
count = len(workers)
|
||||
if count >= expected:
|
||||
print(f"dask_workers_ready count={count} expected={expected}", flush=True)
|
||||
break
|
||||
if time.time() >= deadline:
|
||||
print(
|
||||
f"dask_workers_wait_timeout count={count} expected={expected}",
|
||||
flush=True,
|
||||
)
|
||||
break
|
||||
time.sleep(2)
|
||||
finally:
|
||||
client.close()
|
||||
PY
|
||||
}
|
||||
|
||||
append_mpi_env_args() {
|
||||
[[ "$MPI_EXPORT_ENV" == "1" ]] || return 0
|
||||
mpi_prefix+=(
|
||||
-x "LD_PRELOAD=${LD_PRELOAD:-}"
|
||||
-x "BLIS_NUM_THREADS=$BLIS_NUM_THREADS"
|
||||
-x "OMP_NUM_THREADS=$OMP_NUM_THREADS"
|
||||
-x "MKL_NUM_THREADS=$MKL_NUM_THREADS"
|
||||
-x "OMP_PROC_BIND=$OMP_PROC_BIND"
|
||||
-x "OMP_PLACES=$OMP_PLACES"
|
||||
)
|
||||
}
|
||||
|
||||
build_mpi_prefix() {
|
||||
if [[ -n "$MPIEXEC_FULL" ]]; then
|
||||
# shellcheck disable=SC2206
|
||||
mpi_prefix=($MPIEXEC_FULL)
|
||||
append_mpi_env_args
|
||||
return
|
||||
fi
|
||||
|
||||
local ranks="$MPI_RANKS"
|
||||
if [[ -z "$ranks" && -n "$MPI_HOSTS" ]]; then
|
||||
ranks="$(sum_host_slots "$MPI_HOSTS")"
|
||||
fi
|
||||
if [[ -z "$ranks" ]]; then
|
||||
ranks=2
|
||||
fi
|
||||
|
||||
mpi_prefix=(
|
||||
"$MPIEXEC"
|
||||
--map-by "$MPI_MAP_BY"
|
||||
--bind-to "$MPI_BIND_TO"
|
||||
-np "$ranks"
|
||||
)
|
||||
if [[ "$MPI_REPORT_BINDINGS" == "1" ]]; then
|
||||
mpi_prefix+=(--report-bindings)
|
||||
fi
|
||||
append_mpi_env_args
|
||||
if [[ -n "$MPI_HOSTS" ]]; then
|
||||
mpi_prefix+=(-host "$MPI_HOSTS")
|
||||
elif [[ -n "$MPI_HOSTFILE" ]]; then
|
||||
mpi_prefix+=(-hostfile "$MPI_HOSTFILE")
|
||||
fi
|
||||
}
|
||||
|
||||
is_local_host() {
|
||||
local host="$1"
|
||||
[[ "$host" == "localhost" || "$host" == "127.0.0.1" ]] && return 0
|
||||
@@ -62,25 +202,52 @@ sync_trees_to_hosts() {
|
||||
}
|
||||
|
||||
tools/manage_tn_dask_cluster.sh start
|
||||
DASK_CLUSTER_MANAGED=1
|
||||
wait_for_dask_workers
|
||||
|
||||
echo "Search with dask: $DASK_ADDRESS"
|
||||
"$PYTHON_BIN" -u tools/tn_contest_runner.py search \
|
||||
--case "$CASE" \
|
||||
--nqubits "$NQUBITS" \
|
||||
--nlayers "$NLAYERS" \
|
||||
--observables $OBSERVABLES \
|
||||
--tree-dir "$TREE_DIR" \
|
||||
--dask-address "$DASK_ADDRESS" \
|
||||
--torch-threads "$TORCH_THREADS" \
|
||||
--dtype "$DTYPE" \
|
||||
--tn-search-repeats "$SEARCH_REPEATS" \
|
||||
--tn-search-time "$SEARCH_TIME" \
|
||||
search_args=(
|
||||
--case "$CASE"
|
||||
--nqubits "$NQUBITS"
|
||||
--nlayers "$NLAYERS"
|
||||
--observables $OBSERVABLES
|
||||
--tree-dir "$TREE_DIR"
|
||||
--dask-address "$DASK_ADDRESS"
|
||||
--torch-threads "$TORCH_THREADS"
|
||||
--dtype "$DTYPE"
|
||||
--tn-search-repeats "$SEARCH_REPEATS"
|
||||
--tn-search-time "$SEARCH_TIME"
|
||||
"${tn_slice_args[@]}"
|
||||
)
|
||||
if [[ -n "$DASK_EXPECTED_WORKERS" ]]; then
|
||||
search_args+=(--dask-expected-workers "$DASK_EXPECTED_WORKERS")
|
||||
fi
|
||||
if [[ "$TN_DEBUG_TRIALS" == "1" ]]; then
|
||||
search_args+=(--tn-debug-trials)
|
||||
fi
|
||||
"$PYTHON_BIN" -u tools/tn_contest_runner.py search "${search_args[@]}"
|
||||
|
||||
sync_trees_to_hosts
|
||||
|
||||
echo "Contract with MPI: $MPIEXEC_FULL"
|
||||
read -r -a mpi_prefix <<< "$MPIEXEC_FULL"
|
||||
build_mpi_prefix
|
||||
echo "Contract with MPI: ${mpi_prefix[*]}"
|
||||
if [[ "$TN_CONTRACT_ENV_CHECK" == "1" ]]; then
|
||||
"${mpi_prefix[@]}" "$PYTHON_BIN" -c "from mpi4py import MPI; import os; \
|
||||
import torch; \
|
||||
rank = MPI.COMM_WORLD.Get_rank(); \
|
||||
blis = []; \
|
||||
[blis.append(line.strip().split()[-1]) for line in open('/proc/self/maps') if 'libblis' in line and line.strip().split()[-1] not in blis]; \
|
||||
print('tn_contract_env ' + \
|
||||
f'rank={rank} ' + \
|
||||
f'LD_PRELOAD={os.environ.get(\"LD_PRELOAD\", \"\")} ' + \
|
||||
f'BLIS_NUM_THREADS={os.environ.get(\"BLIS_NUM_THREADS\", \"\")} ' + \
|
||||
f'OMP_NUM_THREADS={os.environ.get(\"OMP_NUM_THREADS\", \"\")} ' + \
|
||||
f'MKL_NUM_THREADS={os.environ.get(\"MKL_NUM_THREADS\", \"\")} ' + \
|
||||
f'OMP_PROC_BIND={os.environ.get(\"OMP_PROC_BIND\", \"\")} ' + \
|
||||
f'OMP_PLACES={os.environ.get(\"OMP_PLACES\", \"\")} ' + \
|
||||
f'torch_threads={torch.get_num_threads()} ' + \
|
||||
f'blis={\";\".join(blis) if blis else \"missing\"}', flush=True)"
|
||||
fi
|
||||
"${mpi_prefix[@]}" "$PYTHON_BIN" -u tools/tn_contest_runner.py contract \
|
||||
--mpi \
|
||||
--case "$CASE" \
|
||||
|
||||
@@ -11,10 +11,15 @@ set -euo pipefail
|
||||
#
|
||||
# Common overrides:
|
||||
# PYTHON_BIN=.venv/bin/python
|
||||
# MPIEXEC=mpiexec
|
||||
# MPIEXEC_FULL="mpirun -np 4 -hostfile /home/yx/qibotn/hostfile -perhost 2"
|
||||
# MPIEXEC=mpirun
|
||||
# MPI_HOSTS="node-1:2,node-2:2,node-3:2,node-0:2"
|
||||
# MPI_RANKS=8
|
||||
# MPI_PE=128
|
||||
# MPI_MAP_BY=ppr:1:numa:PE=128
|
||||
# MPI_BIND_TO=core
|
||||
# MPIEXEC_FULL="mpirun --map-by ppr:1:numa:PE=128 --bind-to core -np 8 -host node-1:2,node-2:2,node-3:2,node-0:2"
|
||||
# HOSTFILE=hostfile # optional; used only if the file exists
|
||||
# RANKS=8
|
||||
# RANKS=8 # fallback if MPI_RANKS is not set
|
||||
# TORCH_THREADS=8
|
||||
# CUT_RATIO=1e-12
|
||||
# OBS_FILTER="boundary_ZZ_q2 ring_xz dense3_spread complex_iZ0"
|
||||
@@ -28,12 +33,23 @@ ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
||||
cd "$ROOT_DIR"
|
||||
|
||||
PYTHON_BIN="${PYTHON_BIN:-.venv/bin/python}"
|
||||
MPIEXEC="${MPIEXEC:-mpiexec}"
|
||||
HOSTFILE="${HOSTFILE:-}"
|
||||
MPIEXEC="${MPIEXEC:-mpirun}"
|
||||
MPIEXEC_FULL="${MPIEXEC_FULL:-}"
|
||||
MPI_HOSTS="${MPI_HOSTS:-}"
|
||||
MPI_HOSTFILE="${MPI_HOSTFILE:-${HOSTFILE:-}}"
|
||||
MPI_RANKS="${MPI_RANKS:-${RANKS:-}}"
|
||||
RANKS="${RANKS:-4}"
|
||||
TORCH_THREADS="${TORCH_THREADS:-1}"
|
||||
MPI_PE="${MPI_PE:-$TORCH_THREADS}"
|
||||
MPI_MAP_BY="${MPI_MAP_BY:-ppr:1:numa:PE=$MPI_PE}"
|
||||
MPI_BIND_TO="${MPI_BIND_TO:-core}"
|
||||
MPI_REPORT_BINDINGS="${MPI_REPORT_BINDINGS:-0}"
|
||||
MPI_EXPORT_ENV="${MPI_EXPORT_ENV:-1}"
|
||||
CUT_RATIO="${CUT_RATIO:-1e-12}"
|
||||
OBS_FILTER="${OBS_FILTER:-}"
|
||||
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-$TORCH_THREADS}"
|
||||
export MKL_NUM_THREADS="${MKL_NUM_THREADS:-$TORCH_THREADS}"
|
||||
source "$ROOT_DIR/tools/qibotn_torch_mt_env.sh"
|
||||
|
||||
RUNNER_DIR="$ROOT_DIR/.tmp"
|
||||
mkdir -p "$RUNNER_DIR"
|
||||
@@ -238,15 +254,68 @@ if __name__ == "__main__":
|
||||
main()
|
||||
PY
|
||||
|
||||
if [[ -n "${MPIEXEC_FULL:-}" ]]; then
|
||||
read -r -a mpi_prefix <<< "$MPIEXEC_FULL"
|
||||
else
|
||||
mpi_prefix=("$MPIEXEC")
|
||||
if [[ -n "$HOSTFILE" && -f "$HOSTFILE" ]]; then
|
||||
mpi_prefix+=("-hostfile" "$HOSTFILE")
|
||||
sum_host_slots() {
|
||||
local hosts="$1"
|
||||
local total=0
|
||||
local item slots
|
||||
IFS=',' read -r -a host_items <<< "$hosts"
|
||||
for item in "${host_items[@]}"; do
|
||||
if [[ "$item" == *:* ]]; then
|
||||
slots="${item##*:}"
|
||||
else
|
||||
slots=1
|
||||
fi
|
||||
total=$((total + slots))
|
||||
done
|
||||
echo "$total"
|
||||
}
|
||||
|
||||
append_mpi_env_args() {
|
||||
[[ "$MPI_EXPORT_ENV" == "1" ]] || return 0
|
||||
mpi_prefix+=(
|
||||
-x "LD_PRELOAD=${LD_PRELOAD:-}"
|
||||
-x "BLIS_NUM_THREADS=$BLIS_NUM_THREADS"
|
||||
-x "OMP_NUM_THREADS=$OMP_NUM_THREADS"
|
||||
-x "MKL_NUM_THREADS=$MKL_NUM_THREADS"
|
||||
-x "OMP_PROC_BIND=$OMP_PROC_BIND"
|
||||
-x "OMP_PLACES=$OMP_PLACES"
|
||||
)
|
||||
}
|
||||
|
||||
build_mpi_prefix() {
|
||||
if [[ -n "$MPIEXEC_FULL" ]]; then
|
||||
# shellcheck disable=SC2206
|
||||
mpi_prefix=($MPIEXEC_FULL)
|
||||
append_mpi_env_args
|
||||
return
|
||||
fi
|
||||
mpi_prefix+=("-n" "$RANKS")
|
||||
fi
|
||||
|
||||
local ranks="$MPI_RANKS"
|
||||
if [[ -z "$ranks" && -n "$MPI_HOSTS" ]]; then
|
||||
ranks="$(sum_host_slots "$MPI_HOSTS")"
|
||||
fi
|
||||
if [[ -z "$ranks" ]]; then
|
||||
ranks="$RANKS"
|
||||
fi
|
||||
|
||||
mpi_prefix=(
|
||||
"$MPIEXEC"
|
||||
--map-by "$MPI_MAP_BY"
|
||||
--bind-to "$MPI_BIND_TO"
|
||||
-np "$ranks"
|
||||
)
|
||||
if [[ "$MPI_REPORT_BINDINGS" == "1" ]]; then
|
||||
mpi_prefix+=(--report-bindings)
|
||||
fi
|
||||
append_mpi_env_args
|
||||
if [[ -n "$MPI_HOSTS" ]]; then
|
||||
mpi_prefix+=(-host "$MPI_HOSTS")
|
||||
elif [[ -n "$MPI_HOSTFILE" ]]; then
|
||||
mpi_prefix+=(-hostfile "$MPI_HOSTFILE")
|
||||
fi
|
||||
}
|
||||
|
||||
build_mpi_prefix
|
||||
|
||||
run_case() {
|
||||
local label="$1"
|
||||
@@ -323,7 +392,12 @@ Cases:
|
||||
Common overrides:
|
||||
PYTHON_BIN=.venv/bin/python
|
||||
MPIEXEC=mpiexec
|
||||
MPIEXEC_FULL="mpirun -np 4 -hostfile /home/yx/qibotn/hostfile -perhost 2"
|
||||
MPI_HOSTS="node-1:2,node-2:2,node-3:2,node-0:2"
|
||||
MPI_RANKS=8
|
||||
MPI_PE=128
|
||||
MPI_MAP_BY=ppr:1:numa:PE=128
|
||||
MPI_BIND_TO=core
|
||||
MPIEXEC_FULL="mpirun --map-by ppr:1:numa:PE=128 --bind-to core -np 8 -host node-1:2,node-2:2,node-3:2,node-0:2"
|
||||
HOSTFILE=hostfile
|
||||
RANKS=8
|
||||
TORCH_THREADS=8
|
||||
|
||||
@@ -47,7 +47,7 @@ CASES = {
|
||||
"main1": CaseSpec(
|
||||
circuit_kind="rxx_rzz_chain",
|
||||
observables=("ring_xz",),
|
||||
nqubits=34,
|
||||
nqubits=37,
|
||||
nlayers=20,
|
||||
seed=31001,
|
||||
target_slices=None,
|
||||
@@ -205,6 +205,8 @@ def build_parallel_opts(args, tree_file=None, search_only=False):
|
||||
opts["search_backend"] = args.tn_search_backend
|
||||
if args.dask_address is not None:
|
||||
opts["dask_address"] = args.dask_address
|
||||
if args.dask_expected_workers is not None:
|
||||
opts["dask_expected_workers"] = args.dask_expected_workers
|
||||
if args.dask_close_workers:
|
||||
opts["dask_close_workers"] = True
|
||||
if args.tn_debug_trials:
|
||||
@@ -378,7 +380,7 @@ def main():
|
||||
parser.add_argument("--quimb-backend", choices=("numpy", "torch"), default="torch")
|
||||
parser.add_argument("--dtype", choices=("complex128", "complex64"), default="complex64")
|
||||
parser.add_argument("--tn-target-slices", type=int)
|
||||
parser.add_argument("--tn-target-size", type=int, default=2**32)
|
||||
parser.add_argument("--tn-target-size", type=int, default=2**34)
|
||||
parser.add_argument("--tn-search-workers", type=int)
|
||||
parser.add_argument("--tn-search-repeats", type=int, default=2048)
|
||||
parser.add_argument("--tn-search-time", type=float, default=300.0)
|
||||
@@ -392,6 +394,7 @@ def main():
|
||||
),
|
||||
)
|
||||
parser.add_argument("--dask-address")
|
||||
parser.add_argument("--dask-expected-workers", type=int)
|
||||
parser.add_argument("--dask-close-workers", action="store_true")
|
||||
parser.add_argument(
|
||||
"--keep-dask",
|
||||
|
||||
Reference in New Issue
Block a user