Remove path cache

This commit is contained in:
tankya2
2023-08-17 10:53:56 +08:00
parent cb21d1d1c0
commit 3fafe2b3ff

View File

@@ -33,7 +33,6 @@ class MPSContractionHelper:
def __init__(self, num_qubits):
self.num_qubits = num_qubits
self.path_cache = {}
self.bra_modes = [(2*i, 2*i+1, 2*i+2) for i in range(num_qubits)]
offset = 2*num_qubits+1
self.ket_modes = [(i+offset, 2*i+1, i+1+offset) for i in range(num_qubits)]
@@ -55,7 +54,7 @@ class MPSContractionHelper:
for i, o in enumerate(mps_tensors):
interleaved_inputs.extend([o, self.bra_modes[i], o.conj(), self.ket_modes[i]])
interleaved_inputs.append([]) # output
return self._contract('norm', interleaved_inputs, options=options).real
return self._contract(interleaved_inputs, options=options).real
def contract_state_vector(self, mps_tensors, options=None):
"""
@@ -75,7 +74,7 @@ class MPSContractionHelper:
interleaved_inputs.extend([o, self.bra_modes[i]])
output_modes = tuple([bra_modes[1] for bra_modes in self.bra_modes])
interleaved_inputs.append(output_modes) # output
return self._contract('sv', interleaved_inputs, options=options)
return self._contract(interleaved_inputs, options=options)
def contract_expectation(self, mps_tensors, operator, qubits, options=None, normalize=False):
"""
@@ -115,13 +114,10 @@ class MPSContractionHelper:
norm = self.contract_norm(mps_tensors, options=options)
else:
norm = 1
return self._contract(f'exp{qubits}', interleaved_inputs, options=options) / norm
return self._contract(interleaved_inputs, options=options) / norm
def _contract(self, key, interleaved_inputs, options=None):
"""
Perform the contraction task given interleaved inputs. Path will be cached.
"""
if key not in self.path_cache:
self.path_cache[key] = contract_path(*interleaved_inputs, options=options)[0]
path = self.path_cache[key]
def _contract(self, interleaved_inputs, options=None):
path = contract_path(*interleaved_inputs, options=options)[0]
return contract(*interleaved_inputs, options=options, optimize={'path':path})