diff --git a/src/qibotn/backends/gpu.py b/src/qibotn/backends/gpu.py index 6301bf7..ad82b75 100644 --- a/src/qibotn/backends/gpu.py +++ b/src/qibotn/backends/gpu.py @@ -8,11 +8,24 @@ from qibo.config import raise_error class CuTensorNet(NumpyBackend): # pragma: no cover # CI does not test for GPU - def __init__(self): + def __init__(self, runcard): super().__init__() import cuquantum # pylint: disable=import-error from cuquantum import cutensornet as cutn # pylint: disable=import-error + if runcard is not None: + print("inside runcard") + # Parse the runcard or use its values to set flags + self.MPI_enabled = runcard.get("MPI_enabled", False) + self.MPS_enabled = runcard.get("MPS_enabled", False) + self.NCCL_enabled = runcard.get("NCCL_enabled", False) + self.expectation_enabled = runcard.get("expectation_enabled", False) + else: + self.MPI_enabled = False + self.MPS_enabled = False + self.NCCL_enabled = False + self.expectation_enabled = False + self.name = "qibotn" self.cuquantum = cuquantum self.cutn = cutn @@ -53,7 +66,7 @@ class CuTensorNet(NumpyBackend): # pragma: no cover raise TypeError("Type can be either complex64 or complex128") def execute_circuit( - self, circuit, MPI_enabled=False, MPS_enabled=False, NCCL_enabled=False, expectation_enabled=False, initial_state=None, nshots=None, return_array=False + self, circuit, initial_state=None, nshots=None, return_array=False ): # pragma: no cover """Executes a quantum circuit. @@ -68,32 +81,31 @@ class CuTensorNet(NumpyBackend): # pragma: no cover """ import qibotn.eval as eval - print("MPI_enabled", MPI_enabled) - print("MPS_enabled", MPS_enabled) - print("NCCL_enabled", NCCL_enabled) - print("expectation_enabled", expectation_enabled) - + + print("MPI_enabled", self.MPI_enabled) + print("MPS_enabled", self.MPS_enabled) + print("NCCL_enabled", self.NCCL_enabled) + print("expectation_enabled", self.expectation_enabled) + if ( - MPI_enabled == False - and MPS_enabled == False - and NCCL_enabled == False - and expectation_enabled == False + self.MPI_enabled == False + and self.MPS_enabled == False + and self.NCCL_enabled == False + and self.expectation_enabled == False ): if initial_state is not None: - raise_error(NotImplementedError, - "QiboTN cannot support initial state.") + raise_error(NotImplementedError, "QiboTN cannot support initial state.") state = eval.dense_vector_tn(circuit, self.dtype) - if ( - MPI_enabled == False - and MPS_enabled == True - and NCCL_enabled == False - and expectation_enabled == False + elif ( + self.MPI_enabled == False + and self.MPS_enabled == True + and self.NCCL_enabled == False + and self.expectation_enabled == False ): if initial_state is not None: - raise_error(NotImplementedError, - "QiboTN cannot support initial state.") + raise_error(NotImplementedError, "QiboTN cannot support initial state.") gate_algo = { "qr_method": False, @@ -104,81 +116,75 @@ class CuTensorNet(NumpyBackend): # pragma: no cover } # make this user input state = eval.dense_vector_mps(circuit, gate_algo, self.dtype) - if ( - MPI_enabled == True - and MPS_enabled == False - and NCCL_enabled == False - and expectation_enabled == False + elif ( + self.MPI_enabled == True + and self.MPS_enabled == False + and self.NCCL_enabled == False + and self.expectation_enabled == False ): if initial_state is not None: - raise_error(NotImplementedError, - "QiboTN cannot support initial state.") + raise_error(NotImplementedError, "QiboTN cannot support initial state.") state, rank = eval.dense_vector_tn_MPI(circuit, self.dtype, 32) if rank > 0: state = np.array(0) - if ( - MPI_enabled == False - and MPS_enabled == False - and NCCL_enabled == True - and expectation_enabled == False + elif ( + self.MPI_enabled == False + and self.MPS_enabled == False + and self.NCCL_enabled == True + and self.expectation_enabled == False ): if initial_state is not None: - raise_error(NotImplementedError, - "QiboTN cannot support initial state.") + raise_error(NotImplementedError, "QiboTN cannot support initial state.") state, rank = eval.dense_vector_tn_nccl(circuit, self.dtype, 32) if rank > 0: state = np.array(0) - if ( - MPI_enabled == False - and MPS_enabled == False - and NCCL_enabled == False - and expectation_enabled == True + elif ( + self.MPI_enabled == False + and self.MPS_enabled == False + and self.NCCL_enabled == False + and self.expectation_enabled == True ): if initial_state is not None: - raise_error(NotImplementedError, - "QiboTN cannot support initial state.") + raise_error(NotImplementedError, "QiboTN cannot support initial state.") - state = eval.expectation_tn(circuit, self.dtype) + state = eval.expectation_pauli_tn(circuit, self.dtype) - if ( - MPI_enabled == True - and MPS_enabled == False - and NCCL_enabled == False - and expectation_enabled == True + elif ( + self.MPI_enabled == True + and self.MPS_enabled == False + and self.NCCL_enabled == False + and self.expectation_enabled == True ): if initial_state is not None: - raise_error(NotImplementedError, - "QiboTN cannot support initial state.") + raise_error(NotImplementedError, "QiboTN cannot support initial state.") - state, rank = eval.expectation_pauli_tn_MPI( - circuit, self.dtype, 32) + state, rank = eval.expectation_pauli_tn_MPI(circuit, self.dtype, 32) if rank > 0: state = np.array(0) - if ( - MPI_enabled == False - and MPS_enabled == False - and NCCL_enabled == True - and expectation_enabled == True + elif ( + self.MPI_enabled == False + and self.MPS_enabled == False + and self.NCCL_enabled == True + and self.expectation_enabled == True ): if initial_state is not None: - raise_error(NotImplementedError, - "QiboTN cannot support initial state.") + raise_error(NotImplementedError, "QiboTN cannot support initial state.") - state, rank = eval.expectation_pauli_tn_nccl( - circuit, self.dtype, 32) + state, rank = eval.expectation_pauli_tn_nccl(circuit, self.dtype, 32) if rank > 0: state = np.array(0) + else: + raise_error(NotImplementedError, "Backend not supported.") if return_array: return state.flatten() else: - circuit._final_state = CircuitResult( - self, circuit, state.flatten(), nshots) + circuit._final_state = CircuitResult(self, circuit, state.flatten(), nshots) return circuit._final_state