Format with Black

This commit is contained in:
tankya2
2023-08-17 13:23:29 +08:00
parent 3fafe2b3ff
commit 89bdbfbe68
3 changed files with 89 additions and 77 deletions

View File

@@ -1,5 +1,6 @@
from cuquantum import contract, contract_path, CircuitToEinsum, tensor
class MPSContractionHelper:
"""
A helper class to compute various quantities for a given MPS.
@@ -35,7 +36,9 @@ class MPSContractionHelper:
self.num_qubits = num_qubits
self.bra_modes = [(2 * i, 2 * i + 1, 2 * i + 2) for i in range(num_qubits)]
offset = 2 * num_qubits + 1
self.ket_modes = [(i+offset, 2*i+1, i+1+offset) for i in range(num_qubits)]
self.ket_modes = [
(i + offset, 2 * i + 1, i + 1 + offset) for i in range(num_qubits)
]
def contract_norm(self, mps_tensors, options=None):
"""
@@ -52,7 +55,9 @@ class MPSContractionHelper:
"""
interleaved_inputs = []
for i, o in enumerate(mps_tensors):
interleaved_inputs.extend([o, self.bra_modes[i], o.conj(), self.ket_modes[i]])
interleaved_inputs.extend(
[o, self.bra_modes[i], o.conj(), self.ket_modes[i]]
)
interleaved_inputs.append([]) # output
return self._contract(interleaved_inputs, options=options).real
@@ -76,7 +81,9 @@ class MPSContractionHelper:
interleaved_inputs.append(output_modes) # output
return self._contract(interleaved_inputs, options=options)
def contract_expectation(self, mps_tensors, operator, qubits, options=None, normalize=False):
def contract_expectation(
self, mps_tensors, operator, qubits, options=None, normalize=False
):
"""
Contract the corresponding tensor network to form the state vector representation of the MPS.
@@ -117,7 +124,6 @@ class MPSContractionHelper:
return self._contract(interleaved_inputs, options=options) / norm
def _contract(self, interleaved_inputs, options=None):
path = contract_path(*interleaved_inputs, options=options)[0]
return contract(*interleaved_inputs, options=options, optimize={'path':path})
return contract(*interleaved_inputs, options=options, optimize={"path": path})

View File

@@ -2,6 +2,7 @@ import cupy as cp
from cuquantum.cutensornet.experimental import contract_decompose
from cuquantum import contract
def initial(num_qubits, dtype):
"""
Generate the MPS with an initial state of |00...00>
@@ -10,25 +11,23 @@ def initial(num_qubits, dtype):
mps_tensors = [state_tensor] * num_qubits
return mps_tensors
def mps_site_right_swap(
mps_tensors,
i,
**kwargs
):
def mps_site_right_swap(mps_tensors, i, **kwargs):
"""
Perform the swap operation between the ith and i+1th MPS tensors.
"""
# contraction followed by QR decomposition
a, _, b = contract_decompose('ipj,jqk->iqj,jpk', *mps_tensors[i:i+2], algorithm=kwargs.get('algorithm',None), options=kwargs.get('options',None))
a, _, b = contract_decompose(
"ipj,jqk->iqj,jpk",
*mps_tensors[i : i + 2],
algorithm=kwargs.get("algorithm", None),
options=kwargs.get("options", None)
)
mps_tensors[i : i + 2] = (a, b)
return mps_tensors
def apply_gate(
mps_tensors,
gate,
qubits,
**kwargs
):
def apply_gate(mps_tensors, gate, qubits, **kwargs):
"""
Apply the gate operand to the MPS tensors in-place.
@@ -52,7 +51,9 @@ def apply_gate(
if n_qubits == 1:
# single-qubit gate
i = qubits[0]
mps_tensors[i] = contract('ipj,qp->iqj', mps_tensors[i], gate, options=kwargs.get('options',None)) # in-place update
mps_tensors[i] = contract(
"ipj,qp->iqj", mps_tensors[i], gate, options=kwargs.get("options", None)
) # in-place update
elif n_qubits == 2:
# two-qubit gate
i, j = qubits
@@ -61,7 +62,13 @@ def apply_gate(
return apply_gate(mps_tensors, gate.transpose(1, 0, 3, 2), (j, i), **kwargs)
elif i + 1 == j:
# two adjacent qubits
a, _, b = contract_decompose('ipj,jqk,rspq->irj,jsk', *mps_tensors[i:i+2], gate, algorithm=kwargs.get('algorithm',None), options=kwargs.get('options',None))
a, _, b = contract_decompose(
"ipj,jqk,rspq->irj,jsk",
*mps_tensors[i : i + 2],
gate,
algorithm=kwargs.get("algorithm", None),
options=kwargs.get("options", None)
)
mps_tensors[i : i + 2] = (a, b) # in-place update
else:
# non-adjacent two-qubit gate

View File

@@ -44,8 +44,7 @@ class QiboCircuitToEinsum:
for key in qubits_frontier:
out_list.append(qubits_frontier[key])
operand_exp_interleave = [x for y in zip(
operands, mode_labels) for x in y]
operand_exp_interleave = [x for y in zip(operands, mode_labels) for x in y]
operand_exp_interleave.append(out_list)
return operand_exp_interleave