From 410a742cc35d8485a601784e18da496f9572a4b8 Mon Sep 17 00:00:00 2001 From: tankya2 Date: Fri, 22 Aug 2025 15:41:03 +0800 Subject: [PATCH] Make runcard optional in init plus some refactoring --- src/qibotn/backends/cutensornet.py | 55 +++++------------------------- 1 file changed, 9 insertions(+), 46 deletions(-) diff --git a/src/qibotn/backends/cutensornet.py b/src/qibotn/backends/cutensornet.py index 7dd1091..ed1c7f2 100644 --- a/src/qibotn/backends/cutensornet.py +++ b/src/qibotn/backends/cutensornet.py @@ -6,22 +6,22 @@ from qibo.config import raise_error from qibotn.backends.abstract import QibotnBackend from qibotn.result import TensorNetworkResult -CUDA_TYPES = {} - class CuTensorNet(QibotnBackend, NumpyBackend): # pragma: no cover # CI does not test for GPU """Creates CuQuantum backend for QiboTN.""" - def __init__(self, runcard): + def __init__(self, runcard=None): super().__init__() - from cuquantum import ( # pylint: disable=import-error - ComputeType, - __version__, - cudaDataType, - ) - from cuquantum import cutensornet as cutn # pylint: disable=import-error + from cuquantum import __version__ # pylint: disable=import-error + self.name = "qibotn" + self.platform = "cutensornet" + self.versions["cuquantum"] = __version__ + self.supports_multigpu = True + self.configure_tn_simulation(runcard) + + def configure_tn_simulation(self, runcard): if runcard is not None: self.MPI_enabled = runcard.get("MPI_enabled", False) self.NCCL_enabled = runcard.get("NCCL_enabled", False) @@ -67,43 +67,6 @@ class CuTensorNet(QibotnBackend, NumpyBackend): # pragma: no cover self.NCCL_enabled = False self.expectation_enabled = False - self.name = "qibotn" - self.cutn = cutn - self.platform = "cutensornet" - self.versions["cuquantum"] = __version__ - self.supports_multigpu = True - self.handle = self.cutn.create() - - global CUDA_TYPES - CUDA_TYPES = { - "complex64": ( - cudaDataType.CUDA_C_32F, - ComputeType.COMPUTE_32F, - ), - "complex128": ( - cudaDataType.CUDA_C_64F, - ComputeType.COMPUTE_64F, - ), - } - - def __del__(self): - if hasattr(self, "cutn"): - self.cutn.destroy(self.handle) - - def cuda_type(self, dtype="complex64"): - """Get CUDA Type. - - Parameters: - dtype (str, optional): Either single ("complex64") or double (complex128) precision. Defaults to "complex64". - - Returns: - CUDA Type: tuple of cuquantum.cudaDataType and cuquantum.ComputeType - """ - if dtype in CUDA_TYPES: - return CUDA_TYPES[dtype] - else: - raise TypeError("Type can be either complex64 or complex128") - def execute_circuit( self, circuit, initial_state=None, nshots=None, return_array=False ): # pragma: no cover