Compare commits

...

2 Commits

Author SHA1 Message Date
5479574502 优化 torch CPU 张量网络收缩路径
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
- torch CPU 收缩默认走 cotengra matmul lowering
  - 复用 mm/bmm/matmul 输出缓冲区,降低中间张量分配压力
  - 仅回收 contiguous tensor,避免非连续 view 进入 workspace
  - 调整 cotengra 中间节点 index 顺序,减少 reshape 触发 clone/copy
  - qibotn MPI 分片收缩显式使用 backend=torch
  - rank 内分片结果先在 torch 中累加,最后再转 numpy 做 Reduce
  - 统一 quimb 后端 torch 数组转换为 CPU contiguous complex128
2026-05-08 11:57:18 +08:00
cec0ba272a 完善脚本功能,添加计时估计 2026-05-08 10:20:03 +08:00
5 changed files with 92 additions and 18 deletions

View File

@@ -31,13 +31,13 @@ else:
tree = None
tree = comm.bcast(tree, root=0)
arrays = [torch.from_numpy(np.ascontiguousarray(t._data, dtype=np.complex128)) for t in tn.tensors]
arrays = [torch.from_numpy(np.asarray(t._data)) for t in tn.tensors]
n_slices = tree.multiplicity
if rank == 0:
print(f"Slices: {n_slices}, Ranks: {size}, "
f"Peak: {tree.max_size() * 16 / 1e9:.2f} GB, "
f"Threads/rank: {max(1, NCORES // size)}, Backend: torch")
f"Threads/rank: {NCORES}, Backend: torch")
t0 = time.time()
result = None

View File

@@ -8,7 +8,7 @@ with open(f"data/tree_q{NQUBITS}_l{NLAYERS}.pkl", 'rb') as f:
print(f"Original peak: {tree.max_size() * 16 / 1e9:.2f} GB")
tree_sliced = tree.slice_and_reconfigure(target_size=2**30) # 2^29 = 8 GB
tree_sliced = tree.slice_and_reconfigure(target_size=2**28)
with open(f"data/tree_q{NQUBITS}_l{NLAYERS}_sliced.pkl", 'wb') as f:
pickle.dump(tree_sliced, f)

View File

@@ -5,9 +5,21 @@ path = sys.argv[1] if len(sys.argv) > 1 else "data/tree_q25_l10.pkl"
with open(path, 'rb') as f:
tree = pickle.load(f)
# Intel 8558P: 96 cores, 2.1GHz, AVX-512 (16 FP64/cycle), FMA x2
# complex128 multiply-add = 6 real FLOPs
CORES = 96
FREQ = 2.1e9
AVX512_FP64 = 16
TFLOPS = CORES * FREQ * AVX512_FP64 * 2 / 1e12 # ~6.45 TFLOPS real FP64
COMPLEX_FLOPS = TFLOPS / 6 # complex128 effective
flops = tree.total_flops()
slices = tree.multiplicity
est_seconds = flops * slices / (COMPLEX_FLOPS * 1e12)
print(f"File: {path}")
print(f"Peak memory elements: {tree.max_size():.2e}")
print(f"Peak memory (GB): {tree.max_size() * 16 / 1e9:.2f}") # complex128 = 16 bytes
print(f"Total FLOPs: {tree.total_flops():.2e}")
print(f"Peak memory (GB): {tree.max_size() * 16 / 1e9:.2f}")
print(f"Total FLOPs: {flops:.2e} x{slices} slices = {flops*slices:.2e}")
print(f"Contraction width: {tree.contraction_width()}")
print(f"Multiplicity (slices): {tree.multiplicity}")
print(f"Multiplicity (slices): {slices}")
print(f"Estimated time (96 cores): {est_seconds:.1f}s ({est_seconds/3600:.2f}h)")

View File

@@ -38,6 +38,36 @@ GATE_MAP = {
}
def _torch_cpu_array(data, dtype=None):
"""Convert array-like data to a contiguous CPU torch tensor."""
import numpy as np
import torch
if isinstance(data, torch.Tensor):
x = data
else:
array = np.asarray(data)
if any(stride < 0 for stride in array.strides):
array = np.ascontiguousarray(array)
x = torch.from_numpy(array)
if x.device.type != "cpu":
x = x.cpu()
if dtype is not None and x.dtype != dtype:
x = x.to(dtype)
if not x.is_contiguous():
x = x.contiguous()
return x
def _arrays_to_backend(arrays, backend, engine):
if backend == "torch":
import torch
return [_torch_cpu_array(array, dtype=torch.complex128) for array in arrays]
return [engine.asarray(array) for array in arrays]
def __init__(self, quimb_backend="numpy", contraction_optimizer="auto-hq"):
super(self.__class__, self).__init__()
@@ -439,6 +469,7 @@ def _expectation_parallel(self, circuit, observable, method, opts):
mpi_contract = opts.get('mpi_contract', False)
torch_threads = opts.get('torch_threads', None)
slicing_opts = opts.get('slicing_opts', None)
trial_timeout = opts.get('trial_timeout', None)
qc = self._qibo_circuit_to_quimb(
circuit,
@@ -472,18 +503,38 @@ def _expectation_parallel(self, circuit, observable, method, opts):
max_time=max_time,
n_workers=search_workers,
slicing_opts=slicing_opts,
trial_timeout=trial_timeout,
)
if tree is None:
continue
if mpi_contract and comm and size > 1:
arrays = [self.engine.asarray(a) for a in tn.arrays]
arrays = _arrays_to_backend(tn.arrays, self.backend, self.engine)
val = parallel_contract(tree, arrays, method='mpi', comm=comm)
else:
for tensor in tn.tensors:
tensor._data = torch.from_numpy(self.engine.asarray(tensor._data)).to(torch.complex128)
val = complex(tn.contract(all, output_inds=(), optimize=tree))
if self.backend == "torch":
for tensor in tn.tensors:
tensor._data = _torch_cpu_array(
tensor._data, dtype=torch.complex128
)
val = complex(
tn.contract(
all,
output_inds=(),
optimize=tree,
backend="torch",
)
)
else:
val = complex(
tn.contract(
all,
output_inds=(),
optimize=tree,
backend=self.backend,
)
)
my_exp += coeff * complex(val)

View File

@@ -170,14 +170,25 @@ def _contract_mpi(tree, arrays, comm, root=0):
rank, size = comm.Get_rank(), comm.Get_size()
is_torch = type(arrays[0]).__module__.startswith("torch")
result_np = None
for i in range(rank, tree.multiplicity, size):
x = tree.contract_slice(arrays, i)
x_np = np.asarray(x.detach().cpu().numpy() if is_torch else x).reshape(-1)
result_np = x_np if result_np is None else result_np + x_np
if is_torch:
result_torch = None
for i in range(rank, tree.multiplicity, size):
x = tree.contract_slice(arrays, i, backend="torch").reshape(-1)
result_torch = x if result_torch is None else result_torch + x
if result_np is None:
result_np = np.zeros(1, dtype=np.complex128)
if result_torch is None:
result_np = np.zeros(1, dtype=np.complex128)
else:
result_np = result_torch.detach().cpu().numpy()
else:
result_np = None
for i in range(rank, tree.multiplicity, size):
x = tree.contract_slice(arrays, i)
x_np = np.asarray(x).reshape(-1)
result_np = x_np if result_np is None else result_np + x_np
if result_np is None:
result_np = np.zeros(1, dtype=np.complex128)
result = np.zeros_like(result_np) if rank == root else None
comm.Reduce(result_np, result, root=root)