diff --git a/src/qibotn/__main__.py b/src/qibotn/__main__.py index b42f84d..c09b2fc 100644 --- a/src/qibotn/__main__.py +++ b/src/qibotn/__main__.py @@ -5,7 +5,7 @@ from qibotn import quimb as qiboquimb from QiboCircuitConvertor import QiboCircuitToEinsum from cuquantum import contract import cupy as cp -from qibo.models import * +from qibo.models import QFT def parser(): @@ -47,6 +47,16 @@ def parser_cuquantum(): return parser.parse_args() +def run_bench(task, label): + start = timer() + result = task() + end = timer() + circuit_eval_time = end - start + print(f"Simulation time: {label} = {circuit_eval_time}s") + + return result + + def main_cuquantum(args: argparse.Namespace): print("Testing for %d nqubits" % (args.nqubits)) nqubits = args.nqubits @@ -60,20 +70,12 @@ def main_cuquantum(args: argparse.Namespace): raise NotImplementedError(f"Cannot find circuit {circuit_name}.") myconvertor = QiboCircuitToEinsum(circuit, dtype=datatype) - expression, operands = myconvertor.state_vector() - start = timer() - result_qibo = circuit() - end = timer() - circuit_eval_time = end - start - print("Simulation time: Qibo =", circuit_eval_time, "s") - - start = timer() - sv_cutn = contract(expression, *operands) - end = timer() - circuit_eval_time = end - start - print("Simulation time: cuQuantum cuTensorNet =", circuit_eval_time, "s") + result_qibo = run_bench(circuit, "Qibo") + sv_cutn = run_bench( + lambda: contract(expression, *operands), "cuQuantum cuTensorNet" + ) # print(f"is sv in agreement?", cp.allclose(sv_cutn.flatten(), result_qibo.state(numpy=True))) assert cp.allclose(sv_cutn.flatten(), result_qibo.state(numpy=True))