Created run_bench to get rid of repeated test code

This commit is contained in:
tankya2
2023-02-13 14:14:53 +08:00
committed by Alessandro Candido
parent 6838faba33
commit bed3a50be5

View File

@@ -5,7 +5,7 @@ from qibotn import quimb as qiboquimb
from QiboCircuitConvertor import QiboCircuitToEinsum from QiboCircuitConvertor import QiboCircuitToEinsum
from cuquantum import contract from cuquantum import contract
import cupy as cp import cupy as cp
from qibo.models import * from qibo.models import QFT
def parser(): def parser():
@@ -47,6 +47,16 @@ def parser_cuquantum():
return parser.parse_args() return parser.parse_args()
def run_bench(task, label):
start = timer()
result = task()
end = timer()
circuit_eval_time = end - start
print(f"Simulation time: {label} = {circuit_eval_time}s")
return result
def main_cuquantum(args: argparse.Namespace): def main_cuquantum(args: argparse.Namespace):
print("Testing for %d nqubits" % (args.nqubits)) print("Testing for %d nqubits" % (args.nqubits))
nqubits = args.nqubits nqubits = args.nqubits
@@ -60,20 +70,12 @@ def main_cuquantum(args: argparse.Namespace):
raise NotImplementedError(f"Cannot find circuit {circuit_name}.") raise NotImplementedError(f"Cannot find circuit {circuit_name}.")
myconvertor = QiboCircuitToEinsum(circuit, dtype=datatype) myconvertor = QiboCircuitToEinsum(circuit, dtype=datatype)
expression, operands = myconvertor.state_vector() expression, operands = myconvertor.state_vector()
start = timer() result_qibo = run_bench(circuit, "Qibo")
result_qibo = circuit() sv_cutn = run_bench(
end = timer() lambda: contract(expression, *operands), "cuQuantum cuTensorNet"
circuit_eval_time = end - start )
print("Simulation time: Qibo =", circuit_eval_time, "s")
start = timer()
sv_cutn = contract(expression, *operands)
end = timer()
circuit_eval_time = end - start
print("Simulation time: cuQuantum cuTensorNet =", circuit_eval_time, "s")
# print(f"is sv in agreement?", cp.allclose(sv_cutn.flatten(), result_qibo.state(numpy=True))) # print(f"is sv in agreement?", cp.allclose(sv_cutn.flatten(), result_qibo.state(numpy=True)))
assert cp.allclose(sv_cutn.flatten(), result_qibo.state(numpy=True)) assert cp.allclose(sv_cutn.flatten(), result_qibo.state(numpy=True))