From aee995802f9986ae3bfa8c0151ce0a98faa300f2 Mon Sep 17 00:00:00 2001 From: tankya2 Date: Tue, 18 Apr 2023 11:34:00 +0800 Subject: [PATCH] Add datatype as an input in eval() [skip ci] --- src/qibotn/cutn.py | 4 ++-- tests/test_cuquantum_cutensor_backend.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/qibotn/cutn.py b/src/qibotn/cutn.py index fe0ae38..026604a 100644 --- a/src/qibotn/cutn.py +++ b/src/qibotn/cutn.py @@ -3,8 +3,8 @@ from QiboCircuitConvertor import QiboCircuitToEinsum from cuquantum import contract -def eval(qibo_circ): - myconvertor = QiboCircuitToEinsum(qibo_circ, dtype="complex128") +def eval(qibo_circ,datatype): + myconvertor = QiboCircuitToEinsum(qibo_circ, dtype=datatype) operands_expression = myconvertor.state_vector() results = contract(*operands_expression) return results.flatten() diff --git a/tests/test_cuquantum_cutensor_backend.py b/tests/test_cuquantum_cutensor_backend.py index 07bb337..30823ef 100644 --- a/tests/test_cuquantum_cutensor_backend.py +++ b/tests/test_cuquantum_cutensor_backend.py @@ -32,7 +32,8 @@ def test_eval(nqubits: int): qibo_time, (qibo_circ, result_sv) = time(lambda: qibo_qft(nqubits, swaps=True)) # Test Cuquantum - cutn_time, result_tn = time(lambda: qibotn.cutn.eval(qibo_circ)) + data_type = "complex128" + cutn_time, result_tn = time(lambda: qibotn.cutn.eval(qibo_circ,data_type)) assert 1e-2 * qibo_time < cutn_time < 1e2 * qibo_time assert np.allclose(result_sv, result_tn), "Resulting dense vectors do not match"