Make datatype a constant dict

This commit is contained in:
tankya2
2024-02-16 15:41:46 +08:00
parent 5c24cc34c5
commit bea3af6f77

View File

@@ -2,6 +2,18 @@ import numpy as np
from qibo.backends.numpy import NumpyBackend
from qibo.config import raise_error
from qibo.result import CircuitResult
import cuquantum # pylint: disable=import-error
CUDA_TYPES = {
"complex64": (
cuquantum.cudaDataType.CUDA_C_32F,
cuquantum.ComputeType.COMPUTE_32F,
),
"complex128": (
cuquantum.cudaDataType.CUDA_C_64F,
cuquantum.ComputeType.COMPUTE_64F,
),
}
class CuTensorNet(NumpyBackend): # pragma: no cover
@@ -9,7 +21,6 @@ class CuTensorNet(NumpyBackend): # pragma: no cover
def __init__(self, runcard):
super().__init__()
import cuquantum # pylint: disable=import-error
from cuquantum import cutensornet as cutn # pylint: disable=import-error
if runcard is not None:
@@ -81,16 +92,8 @@ class CuTensorNet(NumpyBackend): # pragma: no cover
super().set_precision(precision)
def cuda_type(self, dtype="complex64"):
if dtype == "complex128":
return (
self.cuquantum.cudaDataType.CUDA_C_64F,
self.cuquantum.ComputeType.COMPUTE_64F,
)
elif dtype == "complex64":
return (
self.cuquantum.cudaDataType.CUDA_C_32F,
self.cuquantum.ComputeType.COMPUTE_32F,
)
if dtype in CUDA_TYPES:
return CUDA_TYPES[dtype]
else:
raise TypeError("Type can be either complex64 or complex128")