From 54795745024c633ad9615a91c4cb1ae5c37e8942 Mon Sep 17 00:00:00 2001 From: jaunatisblue Date: Fri, 8 May 2026 11:57:18 +0800 Subject: [PATCH] =?UTF-8?q?=20=20=E4=BC=98=E5=8C=96=20torch=20CPU=20?= =?UTF-8?q?=E5=BC=A0=E9=87=8F=E7=BD=91=E7=BB=9C=E6=94=B6=E7=BC=A9=E8=B7=AF?= =?UTF-8?q?=E5=BE=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - torch CPU 收缩默认走 cotengra matmul lowering - 复用 mm/bmm/matmul 输出缓冲区,降低中间张量分配压力 - 仅回收 contiguous tensor,避免非连续 view 进入 workspace - 调整 cotengra 中间节点 index 顺序,减少 reshape 触发 clone/copy - qibotn MPI 分片收缩显式使用 backend=torch - rank 内分片结果先在 torch 中累加,最后再转 numpy 做 Reduce - 统一 quimb 后端 torch 数组转换为 CPU contiguous complex128 --- src/qibotn/backends/quimb.py | 57 +++++++++++++++++++++++++++++++++--- src/qibotn/parallel.py | 25 +++++++++++----- 2 files changed, 71 insertions(+), 11 deletions(-) diff --git a/src/qibotn/backends/quimb.py b/src/qibotn/backends/quimb.py index 6b28500..5d81606 100644 --- a/src/qibotn/backends/quimb.py +++ b/src/qibotn/backends/quimb.py @@ -38,6 +38,36 @@ GATE_MAP = { } +def _torch_cpu_array(data, dtype=None): + """Convert array-like data to a contiguous CPU torch tensor.""" + import numpy as np + import torch + + if isinstance(data, torch.Tensor): + x = data + else: + array = np.asarray(data) + if any(stride < 0 for stride in array.strides): + array = np.ascontiguousarray(array) + x = torch.from_numpy(array) + + if x.device.type != "cpu": + x = x.cpu() + if dtype is not None and x.dtype != dtype: + x = x.to(dtype) + if not x.is_contiguous(): + x = x.contiguous() + return x + + +def _arrays_to_backend(arrays, backend, engine): + if backend == "torch": + import torch + + return [_torch_cpu_array(array, dtype=torch.complex128) for array in arrays] + return [engine.asarray(array) for array in arrays] + + def __init__(self, quimb_backend="numpy", contraction_optimizer="auto-hq"): super(self.__class__, self).__init__() @@ -480,12 +510,31 @@ def _expectation_parallel(self, circuit, observable, method, opts): continue if mpi_contract and comm and size > 1: - arrays = [self.engine.asarray(a) for a in tn.arrays] + arrays = _arrays_to_backend(tn.arrays, self.backend, self.engine) val = parallel_contract(tree, arrays, method='mpi', comm=comm) else: - for tensor in tn.tensors: - tensor._data = torch.from_numpy(self.engine.asarray(tensor._data)).to(torch.complex128) - val = complex(tn.contract(all, output_inds=(), optimize=tree)) + if self.backend == "torch": + for tensor in tn.tensors: + tensor._data = _torch_cpu_array( + tensor._data, dtype=torch.complex128 + ) + val = complex( + tn.contract( + all, + output_inds=(), + optimize=tree, + backend="torch", + ) + ) + else: + val = complex( + tn.contract( + all, + output_inds=(), + optimize=tree, + backend=self.backend, + ) + ) my_exp += coeff * complex(val) diff --git a/src/qibotn/parallel.py b/src/qibotn/parallel.py index bc81ef3..5eee1f1 100644 --- a/src/qibotn/parallel.py +++ b/src/qibotn/parallel.py @@ -170,14 +170,25 @@ def _contract_mpi(tree, arrays, comm, root=0): rank, size = comm.Get_rank(), comm.Get_size() is_torch = type(arrays[0]).__module__.startswith("torch") - result_np = None - for i in range(rank, tree.multiplicity, size): - x = tree.contract_slice(arrays, i) - x_np = np.asarray(x.detach().cpu().numpy() if is_torch else x).reshape(-1) - result_np = x_np if result_np is None else result_np + x_np + if is_torch: + result_torch = None + for i in range(rank, tree.multiplicity, size): + x = tree.contract_slice(arrays, i, backend="torch").reshape(-1) + result_torch = x if result_torch is None else result_torch + x - if result_np is None: - result_np = np.zeros(1, dtype=np.complex128) + if result_torch is None: + result_np = np.zeros(1, dtype=np.complex128) + else: + result_np = result_torch.detach().cpu().numpy() + else: + result_np = None + for i in range(rank, tree.multiplicity, size): + x = tree.contract_slice(arrays, i) + x_np = np.asarray(x).reshape(-1) + result_np = x_np if result_np is None else result_np + x_np + + if result_np is None: + result_np = np.zeros(1, dtype=np.complex128) result = np.zeros_like(result_np) if rank == root else None comm.Reduce(result_np, result, root=root)