From 4d36afb9efde1478a0f274670ed918dd77578bf4 Mon Sep 17 00:00:00 2001 From: Liwei Yang Date: Wed, 19 Apr 2023 16:11:14 +0800 Subject: [PATCH] Expose the precision dtype to the caller so that users can specify the precision for testing --- tests/test_cuquantum_cutensor_backend.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/test_cuquantum_cutensor_backend.py b/tests/test_cuquantum_cutensor_backend.py index 74cf2e1..e438b24 100644 --- a/tests/test_cuquantum_cutensor_backend.py +++ b/tests/test_cuquantum_cutensor_backend.py @@ -23,7 +23,14 @@ def time(func): @pytest.mark.gpu @pytest.mark.parametrize("nqubits", [1, 2, 5, 10]) -def test_eval(nqubits: int): +def test_eval(nqubits: int, dtype="complex128"): + """Evaluate QASM with cuQuantum. + + Args: + nqubits (int): Total number of qubits in the system. + dtype (str): The data type for precision, 'complex64' for single, + 'complex128' for double. + """ import qibotn.cutn # Test qibo @@ -33,8 +40,7 @@ def test_eval(nqubits: int): lambda: qibo_qft(nqubits, swaps=True)) # Test Cuquantum - data_type = "complex128" - cutn_time, result_tn = time(lambda: qibotn.cutn.eval(qibo_circ, data_type)) + cutn_time, result_tn = time(lambda: qibotn.cutn.eval(qibo_circ, dtype)) assert 1e-2 * qibo_time < cutn_time < 1e2 * qibo_time assert np.allclose(