From e517b4fe7c6a9d15ed03e5e8b55c6b17ba625b03 Mon Sep 17 00:00:00 2001 From: Liwei Yang Date: Wed, 19 Apr 2023 16:23:59 +0800 Subject: [PATCH] Avoid flatten() so as to keep the shape information of contraction results --- src/qibotn/cutn.py | 2 +- tests/test_cuquantum_cutensor_backend.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/qibotn/cutn.py b/src/qibotn/cutn.py index 985d7bb..36f7866 100644 --- a/src/qibotn/cutn.py +++ b/src/qibotn/cutn.py @@ -7,4 +7,4 @@ def eval(qibo_circ, datatype): myconvertor = QiboCircuitToEinsum(qibo_circ, dtype=datatype) operands_expression = myconvertor.state_vector() results = contract(*operands_expression) - return results.flatten() + return results diff --git a/tests/test_cuquantum_cutensor_backend.py b/tests/test_cuquantum_cutensor_backend.py index e438b24..e7f2804 100644 --- a/tests/test_cuquantum_cutensor_backend.py +++ b/tests/test_cuquantum_cutensor_backend.py @@ -40,7 +40,8 @@ def test_eval(nqubits: int, dtype="complex128"): lambda: qibo_qft(nqubits, swaps=True)) # Test Cuquantum - cutn_time, result_tn = time(lambda: qibotn.cutn.eval(qibo_circ, dtype)) + cutn_time, result_tn = time( + lambda: qibotn.cutn.eval(qibo_circ, dtype).flatten()) assert 1e-2 * qibo_time < cutn_time < 1e2 * qibo_time assert np.allclose(