diff --git a/src/qibotn/cutn.py b/src/qibotn/cutn.py index 985d7bb..36f7866 100644 --- a/src/qibotn/cutn.py +++ b/src/qibotn/cutn.py @@ -7,4 +7,4 @@ def eval(qibo_circ, datatype): myconvertor = QiboCircuitToEinsum(qibo_circ, dtype=datatype) operands_expression = myconvertor.state_vector() results = contract(*operands_expression) - return results.flatten() + return results diff --git a/tests/test_cuquantum_cutensor_backend.py b/tests/test_cuquantum_cutensor_backend.py index e438b24..e7f2804 100644 --- a/tests/test_cuquantum_cutensor_backend.py +++ b/tests/test_cuquantum_cutensor_backend.py @@ -40,7 +40,8 @@ def test_eval(nqubits: int, dtype="complex128"): lambda: qibo_qft(nqubits, swaps=True)) # Test Cuquantum - cutn_time, result_tn = time(lambda: qibotn.cutn.eval(qibo_circ, dtype)) + cutn_time, result_tn = time( + lambda: qibotn.cutn.eval(qibo_circ, dtype).flatten()) assert 1e-2 * qibo_time < cutn_time < 1e2 * qibo_time assert np.allclose(