From 65a04c32fa38e4c00e2b4785e8f2f4117d0d45ba Mon Sep 17 00:00:00 2001 From: tankya2 Date: Tue, 26 Aug 2025 17:11:08 +0800 Subject: [PATCH] Make rank class attribute --- src/qibotn/backends/cutensornet.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/qibotn/backends/cutensornet.py b/src/qibotn/backends/cutensornet.py index ed1c7f2..616cda9 100644 --- a/src/qibotn/backends/cutensornet.py +++ b/src/qibotn/backends/cutensornet.py @@ -22,6 +22,7 @@ class CuTensorNet(QibotnBackend, NumpyBackend): # pragma: no cover self.configure_tn_simulation(runcard) def configure_tn_simulation(self, runcard): + self.rank = None if runcard is not None: self.MPI_enabled = runcard.get("MPI_enabled", False) self.NCCL_enabled = runcard.get("NCCL_enabled", False) @@ -106,8 +107,8 @@ class CuTensorNet(QibotnBackend, NumpyBackend): # pragma: no cover and self.NCCL_enabled == False and self.expectation_enabled == False ): - state, rank = eval.dense_vector_tn_MPI(circuit, self.dtype, 32) - if rank > 0: + state, self.rank = eval.dense_vector_tn_MPI(circuit, self.dtype, 32) + if self.rank > 0: state = np.array(0) elif ( self.MPI_enabled == False @@ -115,8 +116,8 @@ class CuTensorNet(QibotnBackend, NumpyBackend): # pragma: no cover and self.NCCL_enabled == True and self.expectation_enabled == False ): - state, rank = eval.dense_vector_tn_nccl(circuit, self.dtype, 32) - if rank > 0: + state, self.rank = eval.dense_vector_tn_nccl(circuit, self.dtype, 32) + if self.rank > 0: state = np.array(0) elif ( self.MPI_enabled == False @@ -131,10 +132,10 @@ class CuTensorNet(QibotnBackend, NumpyBackend): # pragma: no cover and self.NCCL_enabled == False and self.expectation_enabled == True ): - state, rank = eval.expectation_tn_MPI( + state, self.rank = eval.expectation_tn_MPI( circuit, self.dtype, self.observable, 32 ) - if rank > 0: + if self.rank > 0: state = np.array(0) elif ( self.MPI_enabled == False @@ -142,10 +143,10 @@ class CuTensorNet(QibotnBackend, NumpyBackend): # pragma: no cover and self.NCCL_enabled == True and self.expectation_enabled == True ): - state, rank = eval.expectation_tn_nccl( + state, self.rank = eval.expectation_tn_nccl( circuit, self.dtype, self.observable, 32 ) - if rank > 0: + if self.rank > 0: state = np.array(0) else: raise_error(NotImplementedError, "Compute type not supported.")