From b2a2bfedf1fd424607219c091345a7b6362ca3eb Mon Sep 17 00:00:00 2001 From: tankya2 Date: Tue, 3 Oct 2023 14:25:28 +0800 Subject: [PATCH] Removed main and added cupy import --- src/qibotn/cutn.py | 21 +-------------------- 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/src/qibotn/cutn.py b/src/qibotn/cutn.py index 9bc1d67..3d42eb7 100644 --- a/src/qibotn/cutn.py +++ b/src/qibotn/cutn.py @@ -4,6 +4,7 @@ from cuquantum import cutensornet as cutn from mpi4py import MPI # this line initializes MPI import multiprocessing from cupy.cuda.runtime import getDeviceCount +import cupy as cp def eval(qibo_circ, datatype): @@ -37,23 +38,3 @@ def eval_tn_MPI(qibo_circ, datatype): cutn.destroy(handle) return result, rank - - -if __name__ == "__main__": - from qibo.models import QFT - import cupy as cp - import numpy as np - - num_qubits = 10 - swaps = True - circ_qibo = QFT(num_qubits, swaps) - - dtype = "complex128" - sv_mpi, rank = eval_tn_MPI(circ_qibo, dtype) - - if rank == 0: - sv_reference = eval(circ_qibo, dtype) - state_vec = np.array(circ_qibo()) - print(f"State vector difference: {abs(sv_mpi-sv_reference).max():0.3e}") - assert cp.allclose(sv_mpi, sv_reference) - assert cp.allclose(sv_mpi.flatten(), state_vec)