diff --git a/src/qibotn/backends/cutensornet.py b/src/qibotn/backends/cutensornet.py index 78ee5b6..022b8af 100644 --- a/src/qibotn/backends/cutensornet.py +++ b/src/qibotn/backends/cutensornet.py @@ -2,6 +2,18 @@ import numpy as np from qibo.backends.numpy import NumpyBackend from qibo.config import raise_error from qibo.result import CircuitResult +import cuquantum # pylint: disable=import-error + +CUDA_TYPES = { + "complex64": ( + cuquantum.cudaDataType.CUDA_C_32F, + cuquantum.ComputeType.COMPUTE_32F, + ), + "complex128": ( + cuquantum.cudaDataType.CUDA_C_64F, + cuquantum.ComputeType.COMPUTE_64F, + ), +} class CuTensorNet(NumpyBackend): # pragma: no cover @@ -9,7 +21,6 @@ class CuTensorNet(NumpyBackend): # pragma: no cover def __init__(self, runcard): super().__init__() - import cuquantum # pylint: disable=import-error from cuquantum import cutensornet as cutn # pylint: disable=import-error if runcard is not None: @@ -81,16 +92,8 @@ class CuTensorNet(NumpyBackend): # pragma: no cover super().set_precision(precision) def cuda_type(self, dtype="complex64"): - if dtype == "complex128": - return ( - self.cuquantum.cudaDataType.CUDA_C_64F, - self.cuquantum.ComputeType.COMPUTE_64F, - ) - elif dtype == "complex64": - return ( - self.cuquantum.cudaDataType.CUDA_C_32F, - self.cuquantum.ComputeType.COMPUTE_32F, - ) + if dtype in CUDA_TYPES: + return CUDA_TYPES[dtype] else: raise TypeError("Type can be either complex64 or complex128")