diff --git a/src/qibotn/cutn.py b/src/qibotn/cutn.py index 9bc1d67..3d42eb7 100644 --- a/src/qibotn/cutn.py +++ b/src/qibotn/cutn.py @@ -4,6 +4,7 @@ from cuquantum import cutensornet as cutn from mpi4py import MPI # this line initializes MPI import multiprocessing from cupy.cuda.runtime import getDeviceCount +import cupy as cp def eval(qibo_circ, datatype): @@ -37,23 +38,3 @@ def eval_tn_MPI(qibo_circ, datatype): cutn.destroy(handle) return result, rank - - -if __name__ == "__main__": - from qibo.models import QFT - import cupy as cp - import numpy as np - - num_qubits = 10 - swaps = True - circ_qibo = QFT(num_qubits, swaps) - - dtype = "complex128" - sv_mpi, rank = eval_tn_MPI(circ_qibo, dtype) - - if rank == 0: - sv_reference = eval(circ_qibo, dtype) - state_vec = np.array(circ_qibo()) - print(f"State vector difference: {abs(sv_mpi-sv_reference).max():0.3e}") - assert cp.allclose(sv_mpi, sv_reference) - assert cp.allclose(sv_mpi.flatten(), state_vec)