diff --git a/src/qibotn/__main__.py b/src/qibotn/__main__.py index f81bfd3..0476be5 100644 --- a/src/qibotn/__main__.py +++ b/src/qibotn/__main__.py @@ -1,11 +1,6 @@ import argparse -from timeit import default_timer as timer -from qibotn import quimb as qiboquimb -from QiboCircuitConvertor import QiboCircuitToEinsum -from cuquantum import contract -import cupy as cp -from qibo.models import QFT +import qibotn.quimb def parser(): @@ -18,65 +13,8 @@ def parser(): def main(args: argparse.Namespace): print("Testing for %d nqubits" % (args.nqubits)) - qiboquimb.eval(args.nqubits, args.qasm_circ, args.init_state) - - -def parser_cuquantum(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--nqubits", default=10, type=int, help="Number of quibits in the circuits." - ) - - parser.add_argument( - "--circuit", - default="qft", - type=str, - help="Type of circuit to use. See README for the list of " - "available circuits.", - ) - - parser.add_argument( - "--precision", - default="complex128", - type=str, - help="Numerical precision of the simulation. " - "Choose between 'complex128' and 'complex64'.", - ) - - return parser.parse_args() - - -def run_bench(task, label): - start = timer() - result = task() - end = timer() - circuit_eval_time = end - start - print(f"Simulation time: {label} = {circuit_eval_time}s") - - return result - - -def main_cuquantum(args: argparse.Namespace): - print("Testing for %d nqubits" % (args.nqubits)) - nqubits = args.nqubits - circuit_name = args.circuit - datatype = args.precision - - if circuit_name in ("qft", "QFT"): - circuit = QFT(nqubits) - else: - raise NotImplementedError(f"Cannot find circuit {circuit_name}.") - - myconvertor = QiboCircuitToEinsum(circuit, dtype=datatype) - operands_expression = myconvertor.state_vector() - - result_qibo = run_bench(circuit, "Qibo") - sv_cutn = run_bench(lambda: contract(*operands_expression), "cuQuantum cuTensorNet") - - # print(f"is sv in agreement?", cp.allclose(sv_cutn.flatten(), result_qibo.state(numpy=True))) - assert cp.allclose(sv_cutn.flatten(), result_qibo.state(numpy=True)) + qibotn.quimb.eval(args.nqubits, args.qasm_circ, args.init_state) if __name__ == "__main__": - main_cuquantum(parser_cuquantum()) + main(parser()) diff --git a/src/qibotn/cutn.py b/src/qibotn/cutn.py index 026604a..36f7866 100644 --- a/src/qibotn/cutn.py +++ b/src/qibotn/cutn.py @@ -1,10 +1,10 @@ # from qibotn import quimb as qiboquimb -from QiboCircuitConvertor import QiboCircuitToEinsum +from qibotn.QiboCircuitConvertor import QiboCircuitToEinsum from cuquantum import contract -def eval(qibo_circ,datatype): +def eval(qibo_circ, datatype): myconvertor = QiboCircuitToEinsum(qibo_circ, dtype=datatype) operands_expression = myconvertor.state_vector() results = contract(*operands_expression) - return results.flatten() + return results diff --git a/tests/test_cuquantum_cutensor_backend.py b/tests/test_cuquantum_cutensor_backend.py index 30823ef..e7f2804 100644 --- a/tests/test_cuquantum_cutensor_backend.py +++ b/tests/test_cuquantum_cutensor_backend.py @@ -1,5 +1,3 @@ -import copy -import os from timeit import default_timer as timer import config @@ -22,18 +20,29 @@ def time(func): time = end - start return time, res + @pytest.mark.gpu @pytest.mark.parametrize("nqubits", [1, 2, 5, 10]) -def test_eval(nqubits: int): +def test_eval(nqubits: int, dtype="complex128"): + """Evaluate QASM with cuQuantum. + + Args: + nqubits (int): Total number of qubits in the system. + dtype (str): The data type for precision, 'complex64' for single, + 'complex128' for double. + """ import qibotn.cutn # Test qibo - qibo.set_backend(backend=config.qibo.backend, platform=config.qibo.platform) - qibo_time, (qibo_circ, result_sv) = time(lambda: qibo_qft(nqubits, swaps=True)) + qibo.set_backend(backend=config.qibo.backend, + platform=config.qibo.platform) + qibo_time, (qibo_circ, result_sv) = time( + lambda: qibo_qft(nqubits, swaps=True)) # Test Cuquantum - data_type = "complex128" - cutn_time, result_tn = time(lambda: qibotn.cutn.eval(qibo_circ,data_type)) + cutn_time, result_tn = time( + lambda: qibotn.cutn.eval(qibo_circ, dtype).flatten()) assert 1e-2 * qibo_time < cutn_time < 1e2 * qibo_time - assert np.allclose(result_sv, result_tn), "Resulting dense vectors do not match" + assert np.allclose( + result_sv, result_tn), "Resulting dense vectors do not match"