diff --git a/src/qibotn/qasm_quimb.py b/src/qibotn/qasm_quimb.py index c0193a3..882b06d 100644 --- a/src/qibotn/qasm_quimb.py +++ b/src/qibotn/qasm_quimb.py @@ -173,6 +173,18 @@ def qasm_QFT(nqubits: int, qasm_str: str, with_swaps: bool = True, psi0=None): return circ +def tn_circ_eval(nqubits, qasm_circ, init_state, swaps=True, tn_lib="quimb", + backend='numpy'): + if tn_lib == "quimb": + circ_quimb = qasm_QFT(nqubits, qasm_circ, swaps, psi0=init_state) + interim = circ_quimb.psi.full_simplify(seq="DRC") + result = interim.to_dense(backend=backend).flatten() + return result + else: + # TODO: Change assert or value. Add cuquantum later + assert False, "Unsupported tensor network library" + + def eval_QI_qft(nqubits, backend="numpy", qibo_backend="qibojit", with_swaps=True): # backend (quimb): numpy, cupy, jax. Passed to ``opt_einsum``. # qibo_backend: qibojit, qibotf, tensorflow, numpy @@ -197,17 +209,14 @@ def eval_QI_qft(nqubits, backend="numpy", qibo_backend="qibojit", with_swaps=Tru # Quimb circuit # convert vector to MPS dims = tuple(2 * np.ones(nqubits, dtype=int)) - init_state_MPS = qtn.tensor_1d.MatrixProductState.from_dense(init_state_quimb, dims) + init_state_MPS = qtn.tensor_1d.MatrixProductState.from_dense( + init_state_quimb, dims) # construct quimb qft circuit start = timer() - circ_quimb = qasm_QFT(nqubits, qasm_circ, with_swaps, psi0=init_state_MPS) - - interim = circ_quimb.psi.full_simplify(seq="DRC") - - result = interim.to_dense(backend=backend) - amplitudes = result.flatten() + amplitudes = tn_circ_eval(nqubits=nqubits, qasm_circ=qasm_circ, + init_state=init_state_MPS, swaps=with_swaps, + tn_lib="quimb") end = timer() quimb_qft_time = end - start - print("quimb time is " + str(quimb_qft_time)) assert np.allclose(amplitudes, amplitudes_reference, atol=1e-06)