决赛现场脚本
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled

This commit is contained in:
jaunatisblue
2026-05-18 01:37:19 +08:00
parent 4c7a10d026
commit ef3d7e9ee6
26 changed files with 894 additions and 62 deletions

View File

@@ -16,3 +16,4 @@ Files here are intentionally secondary:
- `benchmark_tn_mpi.py`, `benchmark_search.py`, `benchmark_slice.py`, `benchmark_contract_sliced.py`, `check_tree.py`: old TN path-search/slicing experiments.
- `qibojit_reference_expectation.py`: state-vector reference helper.
- `validate_vidal_mpi_correctness.py`: focused Vidal MPI correctness helper.
- `mpi_torch_thread_probe.py`: MPI + torch OpenMP affinity and threading probe.

View File

@@ -0,0 +1,157 @@
#!/usr/bin/env python
"""Benchmark qredtea/qtealeaves SVD control modes.
This isolates the tensor split used by MPS updates: a rank-2 tensor is split
with singular values contracted either left or right, then reconstructed to
measure numerical error and timing.
"""
from __future__ import annotations
import argparse
import gc
import statistics
import time
import torch
import qmatchatea
from qredtea.torchapi import QteaTorchTensor
def _dtype(name: str):
return {
"complex64": torch.complex64,
"complex128": torch.complex128,
"float64": torch.float64,
"float32": torch.float32,
}[name]
def _random_matrix(shape, dtype, seed):
gen = torch.Generator(device="cpu")
gen.manual_seed(seed)
if dtype.is_complex:
real_dtype = torch.float32 if dtype == torch.complex64 else torch.float64
real = torch.randn(shape, dtype=real_dtype, generator=gen)
imag = torch.randn(shape, dtype=real_dtype, generator=gen)
return torch.complex(real, imag).to(dtype)
return torch.randn(shape, dtype=dtype, generator=gen)
def _sync():
if torch.cuda.is_available():
torch.cuda.synchronize()
def run_one(matrix, ctrl, max_bond, contract_singvals, repeats):
conv = qmatchatea.QCConvergenceParameters(
max_bond_dimension=max_bond,
cut_ratio=0.0,
svd_ctrl=ctrl,
)
qtensor = QteaTorchTensor.from_elem_array(matrix, dtype=matrix.dtype, device="cpu")
times = []
rel_error = None
kept = None
status = "ok"
error = ""
for i in range(repeats):
gc.collect()
_sync()
t0 = time.perf_counter()
try:
left, right, singvals, _ = qtensor.split_svd(
[0],
[1],
contract_singvals=contract_singvals,
conv_params=conv,
)
except Exception as exc: # noqa: BLE001 - benchmark should keep going
status = "error"
error = repr(exc)
break
_sync()
times.append(time.perf_counter() - t0)
if i == repeats - 1:
left_matrix = left.elem.reshape(matrix.shape[0], -1)
right_matrix = right.elem.reshape(-1, matrix.shape[1])
recon = left_matrix @ right_matrix
rel_error = (
torch.linalg.vector_norm(matrix - recon)
/ torch.linalg.vector_norm(matrix)
).item()
kept = int(singvals.numel())
return {
"ctrl": ctrl,
"contract_singvals": contract_singvals,
"status": status,
"median_ms": float("nan") if not times else statistics.median(times) * 1000,
"min_ms": float("nan") if not times else min(times) * 1000,
"rel_error": rel_error,
"kept": kept,
"error": error,
}
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--shapes", nargs="+", default=("256x1024", "1024x256", "512x512"))
parser.add_argument("--max-bond", type=int, default=128)
parser.add_argument("--dtype", choices=("complex64", "complex128", "float32", "float64"), default="complex128")
parser.add_argument("--threads", type=int, default=8)
parser.add_argument("--repeats", type=int, default=3)
parser.add_argument(
"--controls",
nargs="+",
default=("A", "D", "V", "R", "E", "E!", "X", "X!"),
)
args = parser.parse_args()
torch.set_num_threads(args.threads)
dtype = _dtype(args.dtype)
print(
"svd_benchmark "
f"dtype={args.dtype} threads={torch.get_num_threads()} "
f"max_bond={args.max_bond} repeats={args.repeats}",
flush=True,
)
print(
"columns shape contract ctrl status median_ms min_ms kept rel_error error",
flush=True,
)
for shape_text in args.shapes:
m_text, n_text = shape_text.lower().split("x", 1)
shape = (int(m_text), int(n_text))
matrix = _random_matrix(shape, dtype, seed=sum(shape))
for contract_singvals in ("L", "R"):
for ctrl in args.controls:
result = run_one(
matrix,
ctrl=ctrl,
max_bond=args.max_bond,
contract_singvals=contract_singvals,
repeats=args.repeats,
)
print(
f"row shape={shape_text} "
f"contract={contract_singvals} "
f"ctrl={ctrl} "
f"status={result['status']} "
f"median_ms={result['median_ms']:.3f} "
f"min_ms={result['min_ms']:.3f} "
f"kept={result['kept']} "
f"rel_error={result['rel_error']} "
f"error={result['error']}",
flush=True,
)
if __name__ == "__main__":
main()

View File

@@ -17,10 +17,10 @@ set -euo pipefail
# WORKER_HOSTS="10.20.1.103 10.20.6.101"
# NWORKERS=48
# NTHREADS=1
# ROOT_DIR=/home/yx/qibotn
# ROOT_DIR=/home/qibo/qibotn
# PYTHON_BIN=.venv/bin/python
ROOT_DIR="${ROOT_DIR:-/home/yx/qibotn}"
ROOT_DIR="${ROOT_DIR:-/home/qibo/qibotn}"
PYTHON_BIN="${PYTHON_BIN:-.venv/bin/python}"
SCHEDULER_HOST="${SCHEDULER_HOST:-10.20.1.103}"
SCHEDULER_PORT="${SCHEDULER_PORT:-8786}"

View File

@@ -0,0 +1,182 @@
#!/usr/bin/env python
"""Probe MPI rank placement and whether torch CPU ops use multiple threads.
Run this under mpirun/mpiexec to check:
* which CPUs each rank is allowed to run on,
* whether torch sees the requested intra-op thread count, and
* whether a large CPU tensor op actually consumes more CPU time than wall time.
The script is intentionally small and self-contained so it can be used to debug
MPI launcher affinity and torch OpenMP behavior independently from the TN code
path.
"""
from __future__ import annotations
import argparse
import os
import socket
import time
from pathlib import Path
from mpi4py import MPI
def _dtype_from_name(name):
import torch
mapping = {
"float32": torch.float32,
"float64": torch.float64,
"complex64": torch.complex64,
"complex128": torch.complex128,
}
return mapping[name]
def _make_tensor(shape, dtype):
import torch
if dtype in (torch.complex64, torch.complex128):
base = torch.float32 if dtype == torch.complex64 else torch.float64
return torch.complex(
torch.randn(shape, dtype=base),
torch.randn(shape, dtype=base),
)
return torch.randn(shape, dtype=dtype)
def _bench(label, fn, iters, warmup=2):
for _ in range(warmup):
fn()
start_wall = time.perf_counter()
start_cpu = time.process_time()
checksum = 0.0
for _ in range(iters):
value = fn()
checksum += float(value)
wall = time.perf_counter() - start_wall
cpu = time.process_time() - start_cpu
ratio = cpu / wall if wall > 0 else float("inf")
print(
f"{label} wall={wall:.3f}s cpu={cpu:.3f}s cpu_over_wall={ratio:.2f} "
f"checksum={checksum:.6e}",
flush=True,
)
def _visible_numa_nodes():
nodes = []
for path in sorted(Path("/sys/devices/system/node").glob("node[0-9]*")):
cpulist = path / "cpulist"
if cpulist.exists():
nodes.append(f"{path.name}:{cpulist.read_text(encoding='utf-8').strip()}")
return ",".join(nodes) if nodes else "unknown"
def _dtype_nbytes(name):
return {
"float32": 4,
"float64": 8,
"complex64": 8,
"complex128": 16,
}[name]
def _format_gib(nbytes):
return f"{nbytes / (1024 ** 3):.2f}GiB"
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--threads", type=int, default=48)
parser.add_argument("--n", type=int, default=4096)
parser.add_argument("--iters", type=int, default=4)
parser.add_argument("--dtype", choices=("float32", "float64", "complex64", "complex128"), default="float32")
parser.add_argument("--op", choices=("matmul", "tensordot", "both"), default="both")
parser.add_argument(
"--affinity-only",
action="store_true",
help="Print MPI/torch placement diagnostics without allocating tensors.",
)
args = parser.parse_args()
os.environ.setdefault("OMP_NUM_THREADS", str(args.threads))
os.environ.setdefault("MKL_NUM_THREADS", str(args.threads))
os.environ.setdefault("OMP_PROC_BIND", "close")
os.environ.setdefault("OMP_PLACES", "cores")
import torch
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
torch.set_num_threads(args.threads)
try:
torch.set_num_interop_threads(1)
except Exception:
pass
dtype = _dtype_from_name(args.dtype)
affinity = sorted(os.sched_getaffinity(0))
allowed_list = ""
try:
with open("/proc/self/status", encoding="utf-8") as f:
for line in f:
if line.startswith("Cpus_allowed_list:"):
allowed_list = line.split(":", 1)[1].strip()
break
except OSError:
pass
print(
f"rank={rank}/{size} host={socket.gethostname()} pid={os.getpid()} "
f"affinity_len={len(affinity)} allowed={allowed_list} "
f"torch_threads={torch.get_num_threads()} "
f"torch_interop={torch.get_num_interop_threads()} "
f"OMP_NUM_THREADS={os.environ.get('OMP_NUM_THREADS')} "
f"MKL_NUM_THREADS={os.environ.get('MKL_NUM_THREADS')} "
f"OMP_PROC_BIND={os.environ.get('OMP_PROC_BIND')} "
f"OMP_PLACES={os.environ.get('OMP_PLACES')} "
f"visible_numa={_visible_numa_nodes()}",
flush=True,
)
if rank == 0:
print(torch.__config__.parallel_info(), flush=True)
input_bytes = args.n * args.n * _dtype_nbytes(args.dtype)
min_live_bytes = 3 * input_bytes
print(
f"matrix_n={args.n} dtype={args.dtype} "
f"one_matrix={_format_gib(input_bytes)} "
f"approx_min_live_per_rank={_format_gib(min_live_bytes)} "
f"approx_min_live_all_ranks={_format_gib(min_live_bytes * size)}",
flush=True,
)
comm.Barrier()
if args.affinity_only:
return
a = _make_tensor((args.n, args.n), dtype)
b = _make_tensor((args.n, args.n), dtype)
def run_matmul():
value = (a @ b).sum()
return value.real.item() if value.is_complex() else value.item()
def run_tensordot():
value = torch.tensordot(a, b, dims=1)
value = value.sum()
return value.real.item() if value.is_complex() else value.item()
if args.op in ("matmul", "both"):
_bench("matmul", run_matmul, args.iters)
if args.op in ("tensordot", "both"):
_bench("tensordot", run_tensordot, args.iters)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,22 @@
#!/usr/bin/env bash
# Shared runtime setup for CPU torch TN/MPS runs.
#
# This makes AOCL BLIS use the multithreaded library when available, which is
# required for complex64 tensordot/cgemm to actually use all cores on this host.
QIBOTN_BLIS_MT="${QIBOTN_BLIS_MT:-/home/aocc/aocl/5.2.0/aocc/lib_LP64/libblis-mt.so.5}"
export BLIS_NUM_THREADS="${BLIS_NUM_THREADS:-${OMP_NUM_THREADS:-1}}"
if [[ -f "$QIBOTN_BLIS_MT" ]]; then
case ":${LD_PRELOAD:-}:" in
*":$QIBOTN_BLIS_MT:"*)
;;
*)
export LD_PRELOAD="${LD_PRELOAD:+$LD_PRELOAD:}$QIBOTN_BLIS_MT"
;;
esac
fi
export OMP_PROC_BIND="${OMP_PROC_BIND:-close}"
export OMP_PLACES="${OMP_PLACES:-cores}"

View File

@@ -21,6 +21,7 @@ TN_THREADS="${TN_THREADS:-8}"
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-1}"
export MKL_NUM_THREADS="${MKL_NUM_THREADS:-1}"
source "$ROOT_DIR/tools/qibotn_torch_mt_env.sh"
run_mpi() {
local ranks="$1"

View File

@@ -22,6 +22,7 @@ TN_THREADS="${TN_THREADS:-12}"
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-1}"
export MKL_NUM_THREADS="${MKL_NUM_THREADS:-1}"
source "$ROOT_DIR/tools/qibotn_torch_mt_env.sh"
estimate_mps_memory() {
local nqubits="$1"

View File

@@ -11,25 +11,165 @@ NLAYERS="${NLAYERS:-20}"
TORCH_THREADS="${TORCH_THREADS:-48}"
SEARCH_REPEATS="${SEARCH_REPEATS:-2048}"
SEARCH_TIME="${SEARCH_TIME:-300}"
TN_TARGET_SIZE="${TN_TARGET_SIZE:-8589934592}"
TN_TARGET_SIZE="${TN_TARGET_SIZE:-17179869184}"
TN_TARGET_SLICES="${TN_TARGET_SLICES:-}"
PYTHON_BIN="${PYTHON_BIN:-.venv/bin/python}"
DTYPE="${DTYPE:-complex64}"
TREE_DIR="${TREE_DIR:-trees/contest_tn}"
DASK_ADDRESS="${DASK_ADDRESS:-tcp://10.20.1.103:8786}"
MPIEXEC_FULL="${MPIEXEC_FULL:-mpirun -np 4 -hostfile /home/yx/qibotn/hostfile -perhost 2}"
DASK_EXPECTED_WORKERS="${DASK_EXPECTED_WORKERS:-}"
DASK_WAIT_FOR_WORKERS="${DASK_WAIT_FOR_WORKERS:-1}"
DASK_WAIT_TIMEOUT="${DASK_WAIT_TIMEOUT:-600}"
TN_DEBUG_TRIALS="${TN_DEBUG_TRIALS:-0}"
MPIEXEC="${MPIEXEC:-mpirun}"
MPIEXEC_FULL="${MPIEXEC_FULL:-}"
MPI_HOSTS="${MPI_HOSTS:-}"
MPI_HOSTFILE="${MPI_HOSTFILE:-${HOSTFILE:-}}"
MPI_RANKS="${MPI_RANKS:-}"
MPI_PE="${MPI_PE:-$TORCH_THREADS}"
MPI_MAP_BY="${MPI_MAP_BY:-ppr:1:numa:PE=$MPI_PE}"
MPI_BIND_TO="${MPI_BIND_TO:-core}"
MPI_REPORT_BINDINGS="${MPI_REPORT_BINDINGS:-0}"
MPI_EXPORT_ENV="${MPI_EXPORT_ENV:-1}"
TN_CONTRACT_ENV_CHECK="${TN_CONTRACT_ENV_CHECK:-1}"
SYNC_TREES="${SYNC_TREES:-1}"
SYNC_HOSTS="${SYNC_HOSTS:-${WORKER_HOSTS:-}}"
SSH_BIN="${SSH_BIN:-ssh}"
DASK_CLUSTER_MANAGED="${DASK_CLUSTER_MANAGED:-0}"
export TCM_ENABLE="${TCM_ENABLE:-1}"
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-$TORCH_THREADS}"
export MKL_NUM_THREADS="${MKL_NUM_THREADS:-$TORCH_THREADS}"
source "$ROOT_DIR/tools/qibotn_torch_mt_env.sh"
tn_slice_args=(--tn-target-size "$TN_TARGET_SIZE")
if [[ -n "$TN_TARGET_SLICES" ]]; then
tn_slice_args+=(--tn-target-slices "$TN_TARGET_SLICES")
fi
cleanup_dask_cluster() {
local status=$?
if [[ "$DASK_CLUSTER_MANAGED" == "1" ]]; then
set +e
tools/manage_tn_dask_cluster.sh stop >/dev/null 2>&1 || true
fi
exit "$status"
}
trap cleanup_dask_cluster EXIT INT TERM HUP
sum_host_slots() {
local hosts="$1"
local total=0
local item slots
IFS=',' read -r -a host_items <<< "$hosts"
for item in "${host_items[@]}"; do
if [[ "$item" == *:* ]]; then
slots="${item##*:}"
else
slots=1
fi
total=$((total + slots))
done
echo "$total"
}
count_hosts() {
local hosts="$1"
local count=0
local item
IFS=' ' read -r -a host_items <<< "$hosts"
for item in "${host_items[@]}"; do
[[ -n "$item" ]] && count=$((count + 1))
done
echo "$count"
}
wait_for_dask_workers() {
[[ "$DASK_WAIT_FOR_WORKERS" == "1" ]] || return 0
local expected="$DASK_EXPECTED_WORKERS"
if [[ -z "$expected" && -n "$WORKER_HOSTS" ]]; then
expected=$(( $(count_hosts "$WORKER_HOSTS") * NWORKERS ))
fi
if [[ -z "$expected" || "$expected" -le 0 ]]; then
return 0
fi
echo "Waiting for Dask workers: expected=$expected timeout=${DASK_WAIT_TIMEOUT}s"
"$PYTHON_BIN" - "$DASK_ADDRESS" "$expected" "$DASK_WAIT_TIMEOUT" <<'PY'
import sys
import time
from distributed import Client
address, expected, timeout = sys.argv[1], int(sys.argv[2]), int(sys.argv[3])
deadline = time.time() + timeout
client = Client(address)
try:
while True:
info = client.scheduler_info(n_workers=-1)
workers = info.get("workers", {})
count = len(workers)
if count >= expected:
print(f"dask_workers_ready count={count} expected={expected}", flush=True)
break
if time.time() >= deadline:
print(
f"dask_workers_wait_timeout count={count} expected={expected}",
flush=True,
)
break
time.sleep(2)
finally:
client.close()
PY
}
append_mpi_env_args() {
[[ "$MPI_EXPORT_ENV" == "1" ]] || return 0
mpi_prefix+=(
-x "LD_PRELOAD=${LD_PRELOAD:-}"
-x "BLIS_NUM_THREADS=$BLIS_NUM_THREADS"
-x "OMP_NUM_THREADS=$OMP_NUM_THREADS"
-x "MKL_NUM_THREADS=$MKL_NUM_THREADS"
-x "OMP_PROC_BIND=$OMP_PROC_BIND"
-x "OMP_PLACES=$OMP_PLACES"
)
}
build_mpi_prefix() {
if [[ -n "$MPIEXEC_FULL" ]]; then
# shellcheck disable=SC2206
mpi_prefix=($MPIEXEC_FULL)
append_mpi_env_args
return
fi
local ranks="$MPI_RANKS"
if [[ -z "$ranks" && -n "$MPI_HOSTS" ]]; then
ranks="$(sum_host_slots "$MPI_HOSTS")"
fi
if [[ -z "$ranks" ]]; then
ranks=2
fi
mpi_prefix=(
"$MPIEXEC"
--map-by "$MPI_MAP_BY"
--bind-to "$MPI_BIND_TO"
-np "$ranks"
)
if [[ "$MPI_REPORT_BINDINGS" == "1" ]]; then
mpi_prefix+=(--report-bindings)
fi
append_mpi_env_args
if [[ -n "$MPI_HOSTS" ]]; then
mpi_prefix+=(-host "$MPI_HOSTS")
elif [[ -n "$MPI_HOSTFILE" ]]; then
mpi_prefix+=(-hostfile "$MPI_HOSTFILE")
fi
}
is_local_host() {
local host="$1"
[[ "$host" == "localhost" || "$host" == "127.0.0.1" ]] && return 0
@@ -62,25 +202,52 @@ sync_trees_to_hosts() {
}
tools/manage_tn_dask_cluster.sh start
DASK_CLUSTER_MANAGED=1
wait_for_dask_workers
echo "Search with dask: $DASK_ADDRESS"
"$PYTHON_BIN" -u tools/tn_contest_runner.py search \
--case "$CASE" \
--nqubits "$NQUBITS" \
--nlayers "$NLAYERS" \
--observables $OBSERVABLES \
--tree-dir "$TREE_DIR" \
--dask-address "$DASK_ADDRESS" \
--torch-threads "$TORCH_THREADS" \
--dtype "$DTYPE" \
--tn-search-repeats "$SEARCH_REPEATS" \
--tn-search-time "$SEARCH_TIME" \
search_args=(
--case "$CASE"
--nqubits "$NQUBITS"
--nlayers "$NLAYERS"
--observables $OBSERVABLES
--tree-dir "$TREE_DIR"
--dask-address "$DASK_ADDRESS"
--torch-threads "$TORCH_THREADS"
--dtype "$DTYPE"
--tn-search-repeats "$SEARCH_REPEATS"
--tn-search-time "$SEARCH_TIME"
"${tn_slice_args[@]}"
)
if [[ -n "$DASK_EXPECTED_WORKERS" ]]; then
search_args+=(--dask-expected-workers "$DASK_EXPECTED_WORKERS")
fi
if [[ "$TN_DEBUG_TRIALS" == "1" ]]; then
search_args+=(--tn-debug-trials)
fi
"$PYTHON_BIN" -u tools/tn_contest_runner.py search "${search_args[@]}"
sync_trees_to_hosts
echo "Contract with MPI: $MPIEXEC_FULL"
read -r -a mpi_prefix <<< "$MPIEXEC_FULL"
build_mpi_prefix
echo "Contract with MPI: ${mpi_prefix[*]}"
if [[ "$TN_CONTRACT_ENV_CHECK" == "1" ]]; then
"${mpi_prefix[@]}" "$PYTHON_BIN" -c "from mpi4py import MPI; import os; \
import torch; \
rank = MPI.COMM_WORLD.Get_rank(); \
blis = []; \
[blis.append(line.strip().split()[-1]) for line in open('/proc/self/maps') if 'libblis' in line and line.strip().split()[-1] not in blis]; \
print('tn_contract_env ' + \
f'rank={rank} ' + \
f'LD_PRELOAD={os.environ.get(\"LD_PRELOAD\", \"\")} ' + \
f'BLIS_NUM_THREADS={os.environ.get(\"BLIS_NUM_THREADS\", \"\")} ' + \
f'OMP_NUM_THREADS={os.environ.get(\"OMP_NUM_THREADS\", \"\")} ' + \
f'MKL_NUM_THREADS={os.environ.get(\"MKL_NUM_THREADS\", \"\")} ' + \
f'OMP_PROC_BIND={os.environ.get(\"OMP_PROC_BIND\", \"\")} ' + \
f'OMP_PLACES={os.environ.get(\"OMP_PLACES\", \"\")} ' + \
f'torch_threads={torch.get_num_threads()} ' + \
f'blis={\";\".join(blis) if blis else \"missing\"}', flush=True)"
fi
"${mpi_prefix[@]}" "$PYTHON_BIN" -u tools/tn_contest_runner.py contract \
--mpi \
--case "$CASE" \

View File

@@ -11,10 +11,15 @@ set -euo pipefail
#
# Common overrides:
# PYTHON_BIN=.venv/bin/python
# MPIEXEC=mpiexec
# MPIEXEC_FULL="mpirun -np 4 -hostfile /home/yx/qibotn/hostfile -perhost 2"
# MPIEXEC=mpirun
# MPI_HOSTS="node-1:2,node-2:2,node-3:2,node-0:2"
# MPI_RANKS=8
# MPI_PE=128
# MPI_MAP_BY=ppr:1:numa:PE=128
# MPI_BIND_TO=core
# MPIEXEC_FULL="mpirun --map-by ppr:1:numa:PE=128 --bind-to core -np 8 -host node-1:2,node-2:2,node-3:2,node-0:2"
# HOSTFILE=hostfile # optional; used only if the file exists
# RANKS=8
# RANKS=8 # fallback if MPI_RANKS is not set
# TORCH_THREADS=8
# CUT_RATIO=1e-12
# OBS_FILTER="boundary_ZZ_q2 ring_xz dense3_spread complex_iZ0"
@@ -28,12 +33,23 @@ ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
cd "$ROOT_DIR"
PYTHON_BIN="${PYTHON_BIN:-.venv/bin/python}"
MPIEXEC="${MPIEXEC:-mpiexec}"
HOSTFILE="${HOSTFILE:-}"
MPIEXEC="${MPIEXEC:-mpirun}"
MPIEXEC_FULL="${MPIEXEC_FULL:-}"
MPI_HOSTS="${MPI_HOSTS:-}"
MPI_HOSTFILE="${MPI_HOSTFILE:-${HOSTFILE:-}}"
MPI_RANKS="${MPI_RANKS:-${RANKS:-}}"
RANKS="${RANKS:-4}"
TORCH_THREADS="${TORCH_THREADS:-1}"
MPI_PE="${MPI_PE:-$TORCH_THREADS}"
MPI_MAP_BY="${MPI_MAP_BY:-ppr:1:numa:PE=$MPI_PE}"
MPI_BIND_TO="${MPI_BIND_TO:-core}"
MPI_REPORT_BINDINGS="${MPI_REPORT_BINDINGS:-0}"
MPI_EXPORT_ENV="${MPI_EXPORT_ENV:-1}"
CUT_RATIO="${CUT_RATIO:-1e-12}"
OBS_FILTER="${OBS_FILTER:-}"
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-$TORCH_THREADS}"
export MKL_NUM_THREADS="${MKL_NUM_THREADS:-$TORCH_THREADS}"
source "$ROOT_DIR/tools/qibotn_torch_mt_env.sh"
RUNNER_DIR="$ROOT_DIR/.tmp"
mkdir -p "$RUNNER_DIR"
@@ -238,15 +254,68 @@ if __name__ == "__main__":
main()
PY
if [[ -n "${MPIEXEC_FULL:-}" ]]; then
read -r -a mpi_prefix <<< "$MPIEXEC_FULL"
else
mpi_prefix=("$MPIEXEC")
if [[ -n "$HOSTFILE" && -f "$HOSTFILE" ]]; then
mpi_prefix+=("-hostfile" "$HOSTFILE")
sum_host_slots() {
local hosts="$1"
local total=0
local item slots
IFS=',' read -r -a host_items <<< "$hosts"
for item in "${host_items[@]}"; do
if [[ "$item" == *:* ]]; then
slots="${item##*:}"
else
slots=1
fi
total=$((total + slots))
done
echo "$total"
}
append_mpi_env_args() {
[[ "$MPI_EXPORT_ENV" == "1" ]] || return 0
mpi_prefix+=(
-x "LD_PRELOAD=${LD_PRELOAD:-}"
-x "BLIS_NUM_THREADS=$BLIS_NUM_THREADS"
-x "OMP_NUM_THREADS=$OMP_NUM_THREADS"
-x "MKL_NUM_THREADS=$MKL_NUM_THREADS"
-x "OMP_PROC_BIND=$OMP_PROC_BIND"
-x "OMP_PLACES=$OMP_PLACES"
)
}
build_mpi_prefix() {
if [[ -n "$MPIEXEC_FULL" ]]; then
# shellcheck disable=SC2206
mpi_prefix=($MPIEXEC_FULL)
append_mpi_env_args
return
fi
mpi_prefix+=("-n" "$RANKS")
fi
local ranks="$MPI_RANKS"
if [[ -z "$ranks" && -n "$MPI_HOSTS" ]]; then
ranks="$(sum_host_slots "$MPI_HOSTS")"
fi
if [[ -z "$ranks" ]]; then
ranks="$RANKS"
fi
mpi_prefix=(
"$MPIEXEC"
--map-by "$MPI_MAP_BY"
--bind-to "$MPI_BIND_TO"
-np "$ranks"
)
if [[ "$MPI_REPORT_BINDINGS" == "1" ]]; then
mpi_prefix+=(--report-bindings)
fi
append_mpi_env_args
if [[ -n "$MPI_HOSTS" ]]; then
mpi_prefix+=(-host "$MPI_HOSTS")
elif [[ -n "$MPI_HOSTFILE" ]]; then
mpi_prefix+=(-hostfile "$MPI_HOSTFILE")
fi
}
build_mpi_prefix
run_case() {
local label="$1"
@@ -323,7 +392,12 @@ Cases:
Common overrides:
PYTHON_BIN=.venv/bin/python
MPIEXEC=mpiexec
MPIEXEC_FULL="mpirun -np 4 -hostfile /home/yx/qibotn/hostfile -perhost 2"
MPI_HOSTS="node-1:2,node-2:2,node-3:2,node-0:2"
MPI_RANKS=8
MPI_PE=128
MPI_MAP_BY=ppr:1:numa:PE=128
MPI_BIND_TO=core
MPIEXEC_FULL="mpirun --map-by ppr:1:numa:PE=128 --bind-to core -np 8 -host node-1:2,node-2:2,node-3:2,node-0:2"
HOSTFILE=hostfile
RANKS=8
TORCH_THREADS=8

View File

@@ -47,7 +47,7 @@ CASES = {
"main1": CaseSpec(
circuit_kind="rxx_rzz_chain",
observables=("ring_xz",),
nqubits=34,
nqubits=37,
nlayers=20,
seed=31001,
target_slices=None,
@@ -205,6 +205,8 @@ def build_parallel_opts(args, tree_file=None, search_only=False):
opts["search_backend"] = args.tn_search_backend
if args.dask_address is not None:
opts["dask_address"] = args.dask_address
if args.dask_expected_workers is not None:
opts["dask_expected_workers"] = args.dask_expected_workers
if args.dask_close_workers:
opts["dask_close_workers"] = True
if args.tn_debug_trials:
@@ -378,7 +380,7 @@ def main():
parser.add_argument("--quimb-backend", choices=("numpy", "torch"), default="torch")
parser.add_argument("--dtype", choices=("complex128", "complex64"), default="complex64")
parser.add_argument("--tn-target-slices", type=int)
parser.add_argument("--tn-target-size", type=int, default=2**32)
parser.add_argument("--tn-target-size", type=int, default=2**34)
parser.add_argument("--tn-search-workers", type=int)
parser.add_argument("--tn-search-repeats", type=int, default=2048)
parser.add_argument("--tn-search-time", type=float, default=300.0)
@@ -392,6 +394,7 @@ def main():
),
)
parser.add_argument("--dask-address")
parser.add_argument("--dask-expected-workers", type=int)
parser.add_argument("--dask-close-workers", action="store_true")
parser.add_argument(
"--keep-dask",