Make runcard optional in init plus some refactoring

This commit is contained in:
tankya2
2025-08-22 15:41:03 +08:00
parent 791c5d2020
commit 410a742cc3

View File

@@ -6,22 +6,22 @@ from qibo.config import raise_error
from qibotn.backends.abstract import QibotnBackend
from qibotn.result import TensorNetworkResult
CUDA_TYPES = {}
class CuTensorNet(QibotnBackend, NumpyBackend): # pragma: no cover
# CI does not test for GPU
"""Creates CuQuantum backend for QiboTN."""
def __init__(self, runcard):
def __init__(self, runcard=None):
super().__init__()
from cuquantum import ( # pylint: disable=import-error
ComputeType,
__version__,
cudaDataType,
)
from cuquantum import cutensornet as cutn # pylint: disable=import-error
from cuquantum import __version__ # pylint: disable=import-error
self.name = "qibotn"
self.platform = "cutensornet"
self.versions["cuquantum"] = __version__
self.supports_multigpu = True
self.configure_tn_simulation(runcard)
def configure_tn_simulation(self, runcard):
if runcard is not None:
self.MPI_enabled = runcard.get("MPI_enabled", False)
self.NCCL_enabled = runcard.get("NCCL_enabled", False)
@@ -67,43 +67,6 @@ class CuTensorNet(QibotnBackend, NumpyBackend): # pragma: no cover
self.NCCL_enabled = False
self.expectation_enabled = False
self.name = "qibotn"
self.cutn = cutn
self.platform = "cutensornet"
self.versions["cuquantum"] = __version__
self.supports_multigpu = True
self.handle = self.cutn.create()
global CUDA_TYPES
CUDA_TYPES = {
"complex64": (
cudaDataType.CUDA_C_32F,
ComputeType.COMPUTE_32F,
),
"complex128": (
cudaDataType.CUDA_C_64F,
ComputeType.COMPUTE_64F,
),
}
def __del__(self):
if hasattr(self, "cutn"):
self.cutn.destroy(self.handle)
def cuda_type(self, dtype="complex64"):
"""Get CUDA Type.
Parameters:
dtype (str, optional): Either single ("complex64") or double (complex128) precision. Defaults to "complex64".
Returns:
CUDA Type: tuple of cuquantum.cudaDataType and cuquantum.ComputeType
"""
if dtype in CUDA_TYPES:
return CUDA_TYPES[dtype]
else:
raise TypeError("Type can be either complex64 or complex128")
def execute_circuit(
self, circuit, initial_state=None, nshots=None, return_array=False
): # pragma: no cover