Take out repeat codes

This commit is contained in:
tankya2
2024-02-09 10:29:30 +08:00
committed by yangliwei
parent edf64ebeb3
commit 2b6345c8ca

View File

@@ -110,96 +110,70 @@ class CuTensorNet(NumpyBackend): # pragma: no cover
import qibotn.eval as eval import qibotn.eval as eval
if initial_state is not None:
raise_error(NotImplementedError, "QiboTN cannot support initial state.")
if ( if (
self.MPI_enabled == False self.MPI_enabled == False
and self.MPS_enabled == False and self.MPS_enabled == False
and self.NCCL_enabled == False and self.NCCL_enabled == False
and self.expectation_enabled == False and self.expectation_enabled == False
): ):
if initial_state is not None:
raise_error(NotImplementedError, "QiboTN cannot support initial state.")
state = eval.dense_vector_tn(circuit, self.dtype) state = eval.dense_vector_tn(circuit, self.dtype)
elif ( elif (
self.MPI_enabled == False self.MPI_enabled == False
and self.MPS_enabled == True and self.MPS_enabled == True
and self.NCCL_enabled == False and self.NCCL_enabled == False
and self.expectation_enabled == False and self.expectation_enabled == False
): ):
if initial_state is not None:
raise_error(NotImplementedError, "QiboTN cannot support initial state.")
state = eval.dense_vector_mps(circuit, self.gate_algo, self.dtype) state = eval.dense_vector_mps(circuit, self.gate_algo, self.dtype)
elif ( elif (
self.MPI_enabled == True self.MPI_enabled == True
and self.MPS_enabled == False and self.MPS_enabled == False
and self.NCCL_enabled == False and self.NCCL_enabled == False
and self.expectation_enabled == False and self.expectation_enabled == False
): ):
if initial_state is not None:
raise_error(NotImplementedError, "QiboTN cannot support initial state.")
state, rank = eval.dense_vector_tn_MPI(circuit, self.dtype, 32) state, rank = eval.dense_vector_tn_MPI(circuit, self.dtype, 32)
if rank > 0: if rank > 0:
state = np.array(0) state = np.array(0)
elif ( elif (
self.MPI_enabled == False self.MPI_enabled == False
and self.MPS_enabled == False and self.MPS_enabled == False
and self.NCCL_enabled == True and self.NCCL_enabled == True
and self.expectation_enabled == False and self.expectation_enabled == False
): ):
if initial_state is not None:
raise_error(NotImplementedError, "QiboTN cannot support initial state.")
state, rank = eval.dense_vector_tn_nccl(circuit, self.dtype, 32) state, rank = eval.dense_vector_tn_nccl(circuit, self.dtype, 32)
if rank > 0: if rank > 0:
state = np.array(0) state = np.array(0)
elif ( elif (
self.MPI_enabled == False self.MPI_enabled == False
and self.MPS_enabled == False and self.MPS_enabled == False
and self.NCCL_enabled == False and self.NCCL_enabled == False
and self.expectation_enabled == True and self.expectation_enabled == True
): ):
if initial_state is not None:
raise_error(NotImplementedError, "QiboTN cannot support initial state.")
state = eval.expectation_pauli_tn( state = eval.expectation_pauli_tn(
circuit, self.dtype, self.pauli_string_pattern circuit, self.dtype, self.pauli_string_pattern
) )
elif ( elif (
self.MPI_enabled == True self.MPI_enabled == True
and self.MPS_enabled == False and self.MPS_enabled == False
and self.NCCL_enabled == False and self.NCCL_enabled == False
and self.expectation_enabled == True and self.expectation_enabled == True
): ):
if initial_state is not None:
raise_error(NotImplementedError, "QiboTN cannot support initial state.")
state, rank = eval.expectation_pauli_tn_MPI( state, rank = eval.expectation_pauli_tn_MPI(
circuit, self.dtype, self.pauli_string_pattern, 32 circuit, self.dtype, self.pauli_string_pattern, 32
) )
if rank > 0: if rank > 0:
state = np.array(0) state = np.array(0)
elif ( elif (
self.MPI_enabled == False self.MPI_enabled == False
and self.MPS_enabled == False and self.MPS_enabled == False
and self.NCCL_enabled == True and self.NCCL_enabled == True
and self.expectation_enabled == True and self.expectation_enabled == True
): ):
if initial_state is not None:
raise_error(NotImplementedError, "QiboTN cannot support initial state.")
state, rank = eval.expectation_pauli_tn_nccl( state, rank = eval.expectation_pauli_tn_nccl(
circuit, self.dtype, self.pauli_string_pattern, 32 circuit, self.dtype, self.pauli_string_pattern, 32
) )
if rank > 0: if rank > 0:
state = np.array(0) state = np.array(0)
else: else: