构建基于oneapi的mpi4py,quimb支持mpi多机并行,缩短路径找寻时间
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
This commit is contained in:
68
tests/quimb_mpi.py
Normal file
68
tests/quimb_mpi.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import os
|
||||
import time
|
||||
import numpy as np
|
||||
import quimb.tensor as qtn
|
||||
import cotengra as ctg
|
||||
'''
|
||||
# --- 1. 关键:在导入 numpy/quimb 之前设置环境变量 ---
|
||||
# 告诉底层 BLAS 库 (MKL/OpenBLAS) 使用 96 个线程
|
||||
os.environ["OMP_NUM_THREADS"] = "1"
|
||||
os.environ["MKL_NUM_THREADS"] = "1"
|
||||
os.environ["OPENBLAS_NUM_THREADS"] = "1"
|
||||
# 优化线程亲和性,避免线程在不同 CPU 核心间跳变,提升缓存命中率
|
||||
os.environ["KMP_AFFINITY"] = "granularity=fine,compact,1,0"
|
||||
os.environ["KMP_BLOCKTIME"] = "0"
|
||||
'''
|
||||
# 现在导入库
|
||||
import psutil
|
||||
|
||||
def run_baseline(n_qubits=50, depth=20):
|
||||
print(f"🚀 {n_qubits} Qubits, Depth {depth}")
|
||||
print(f"💻 Detected Logical Cores: {os.cpu_count()}")
|
||||
|
||||
# 1. 构建电路 (必须 complex128 保证精度)
|
||||
circ = qtn.Circuit(n_qubits, dtype=np.complex128)
|
||||
for d in range(depth):
|
||||
for i in range(n_qubits):
|
||||
circ.apply_gate('H', i)
|
||||
for i in range(0, n_qubits - 1, 2):
|
||||
circ.apply_gate('CZ', i, i + 1)
|
||||
|
||||
psi = circ.psi
|
||||
|
||||
# 2. 构建闭合网络 <psi|psi>
|
||||
net = psi.conj() & psi
|
||||
|
||||
# 3. 路径搜索参数 (Kahypar)
|
||||
print("🔍 Searching path with Kahypar...")
|
||||
opt = ctg.HyperOptimizer(
|
||||
methods=['kahypar'],
|
||||
max_repeats=128,
|
||||
parallel=96,
|
||||
minimize='flops',
|
||||
on_trial_error='ignore'
|
||||
)
|
||||
|
||||
# 4. 阶段1:路径搜索
|
||||
t0 = time.perf_counter()
|
||||
tree = net.contraction_tree(optimize=opt)
|
||||
t1 = time.perf_counter()
|
||||
print(f"🔍 Path search done: {t1 - t0:.4f} s")
|
||||
|
||||
# 5. 阶段2:张量收缩
|
||||
result = net.contract(optimize=tree, backend='numpy')
|
||||
t2 = time.perf_counter()
|
||||
peak_mem = psutil.Process().memory_info().rss / 1024**3
|
||||
|
||||
print(f"✅ Done!")
|
||||
print(f"⏱️ Contract: {t2 - t1:.4f} s | Total: {t2 - t0:.4f} s")
|
||||
print(f"💾 Peak Memory: {peak_mem:.2f} GB")
|
||||
print(f"🔢 Result: {result:.10f}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--n_qubits", type=int, default=50)
|
||||
parser.add_argument("--depth", type=int, default=20)
|
||||
args = parser.parse_args()
|
||||
run_baseline(n_qubits=args.n_qubits, depth=args.depth)
|
||||
Reference in New Issue
Block a user