diff --git a/src/qibotn/qasm_quimb.py b/src/qibotn/qasm_quimb.py index 882b06d..dd94db8 100644 --- a/src/qibotn/qasm_quimb.py +++ b/src/qibotn/qasm_quimb.py @@ -173,6 +173,19 @@ def qasm_QFT(nqubits: int, qasm_str: str, with_swaps: bool = True, psi0=None): return circ +def init_state_tn(nqubits, init_state_sv, tn_lib="quimb"): + dims = tuple(2 * np.ones(nqubits, dtype=int)) + + if tn_lib == "quimb": + init_state_MPS = qtn.tensor_1d.MatrixProductState.from_dense( + init_state_sv, dims) + else: + # TODO: Add cuquantum later + assert False, "Unsupported tensor network backend in initilization" + + return init_state_MPS + + def tn_circ_eval(nqubits, qasm_circ, init_state, swaps=True, tn_lib="quimb", backend='numpy'): if tn_lib == "quimb": @@ -207,10 +220,8 @@ def eval_QI_qft(nqubits, backend="numpy", qibo_backend="qibojit", with_swaps=Tru ##################################################################### # Quimb circuit - # convert vector to MPS - dims = tuple(2 * np.ones(nqubits, dtype=int)) - init_state_MPS = qtn.tensor_1d.MatrixProductState.from_dense( - init_state_quimb, dims) + init_state_MPS = init_state_tn(nqubits=nqubits, + init_state_sv=init_state_quimb) # construct quimb qft circuit start = timer()