From 9890d1ffedc587aa0923cbb3d762551bef060fab Mon Sep 17 00:00:00 2001 From: tankya2 Date: Mon, 13 Feb 2023 14:14:53 +0800 Subject: [PATCH] Created run_bench to get rid of repeated test code --- src/qibotn/__main__.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/qibotn/__main__.py b/src/qibotn/__main__.py index 61b61e8..e6dad2f 100644 --- a/src/qibotn/__main__.py +++ b/src/qibotn/__main__.py @@ -21,6 +21,16 @@ def parser(): return parser.parse_args() +def run_bench(task, label): + + start = timer() + result = task() + end = timer() + circuit_eval_time = end - start + print(f"Simulation time: {label} = {circuit_eval_time}s") + + return result + def main(args: argparse.Namespace): print("Testing for %d nqubits" % (args.nqubits)) @@ -35,20 +45,10 @@ def main(args: argparse.Namespace): raise NotImplementedError(f"Cannot find circuit {circuit_name}.") myconvertor = QiboCircuitToEinsum(circuit, dtype=datatype) - expression, operands = myconvertor.state_vector() - start = timer() - result_qibo = circuit() - end = timer() - circuit_eval_time = end - start - print("Simulation time: Qibo =",circuit_eval_time, 's') - - start = timer() - sv_cutn = contract(expression, *operands) - end = timer() - circuit_eval_time = end - start - print("Simulation time: cuQuantum cuTensorNet =",circuit_eval_time, 's') + result_qibo = run_bench(circuit, "Qibo") + sv_cutn = run_bench(lambda:contract(expression, *operands), "cuQuantum cuTensorNet") #print(f"is sv in agreement?", cp.allclose(sv_cutn.flatten(), result_qibo.state(numpy=True))) assert cp.allclose(sv_cutn.flatten(), result_qibo.state(numpy=True))