refactor: move the imports outside of the backend init

This commit is contained in:
MatteoRobbiati
2025-01-28 14:38:27 +01:00
parent 6fe2c32c0d
commit 91b4b63130
2 changed files with 32 additions and 31 deletions

View File

@@ -3,6 +3,16 @@ from abc import abstractmethod
from qibo.backends.numpy import NumpyBackend from qibo.backends.numpy import NumpyBackend
from qibo.config import raise_error from qibo.config import raise_error
DEFAULT_CONFIGURATION = {
"MPI_enabled": False, # TODO: cutensornet specific, TBRemoved
"NCCL_enabled": False, # TODO: cutensornet specific, TBRemoved
"expectation_enabled": False,
"pauli_string_pattern": None,
"MPS_enabled": False,
"gate_algo": None,
"mps_opts": None,
}
class QibotnBackend(NumpyBackend): class QibotnBackend(NumpyBackend):

View File

@@ -2,8 +2,10 @@
from dataclasses import dataclass from dataclasses import dataclass
import qiskit
import qmatchatea
import qtealeaves
from qibo.config import raise_error from qibo.config import raise_error
from qiskit import QuantumCircuit
from qibotn.backends.abstract import QibotnBackend from qibotn.backends.abstract import QibotnBackend
from qibotn.result import TensorNetworkResult from qibotn.result import TensorNetworkResult
@@ -18,21 +20,12 @@ class QMatchaTeaBackend(QibotnBackend):
self.name = "qiboml" self.name = "qiboml"
self.platform = "qmatchatea" self.platform = "qmatchatea"
import qiskit # pylint: disable=import-error
import qmatchatea # pylint: disable=import-error
import qtealeaves # pylint: disable=import-error
# TODO: move outside of the class
self.qmatchatea = qmatchatea
self.qiskit = qiskit
self.qtleaves = qtealeaves
# Set default configurations # Set default configurations
self.configure_tn_simulation() self.configure_tn_simulation()
# TODO: update this function whenever ``set_device`` and ``set_precision`` # TODO: update this function whenever ``set_device`` and ``set_precision``
# are set (?) # are set (?)
self._setup_qmatchatea_backend() self._setup_qmatchatea_backend()
self._observables = self.qtleaves.observables.TNObservables() self._observables = qtealeaves.observables.TNObservables()
@property @property
def observables(self): def observables(self):
@@ -45,9 +38,9 @@ class QMatchaTeaBackend(QibotnBackend):
It accepts a dict of objects among the ones proposed in ``qtealeaves.observables``. It accepts a dict of objects among the ones proposed in ``qtealeaves.observables``.
""" """
self._observables = self.qtleaves.observables.TNObservables() self._observables = qtealeaves.observables.TNObservables()
for obs in observables: for obs in observables:
if isinstance(obs, self.qtleaves.observables.tnobase._TNObsBase): if isinstance(obs, qtealeaves.observables.tnobase._TNObsBase):
self._observables += obs self._observables += obs
else: else:
raise TypeError("Expected an instance of TNObservables") raise TypeError("Expected an instance of TNObservables")
@@ -98,27 +91,25 @@ class QMatchaTeaBackend(QibotnBackend):
# TODO: check # TODO: check
circuit = self._qibocirc_to_qiskitcirc(circuit) circuit = self._qibocirc_to_qiskitcirc(circuit)
run_qk_params = self.qmatchatea.preprocessing.qk_transpilation_params(False) run_qk_params = qmatchatea.preprocessing.qk_transpilation_params(False)
# Initialize the TNObservable object # Initialize the TNObservable object
observables = self.qtleaves.observables.TNObservables() observables = qtealeaves.observables.TNObservables()
# Shots # Shots
if nshots is not None: if nshots is not None:
observables += self.qtleaves.observables.TNObsProjective(num_shots=nshots) observables += qtealeaves.observables.TNObsProjective(num_shots=nshots)
# Probabilities # Probabilities
observables += self.qtleaves.observables.TNObsProbabilities( observables += qtealeaves.observables.TNObsProbabilities(
prob_type=prob_type, prob_type=prob_type,
**prob_kwargs, **prob_kwargs,
) )
# State # State
observables += self.qtleaves.observables.TNState2File( observables += qtealeaves.observables.TNState2File(name="temp", formatting="U")
name="temp", formatting="U"
)
results = self.qmatchatea.run_simulation( results = qmatchatea.run_simulation(
circ=circuit, circ=circuit,
convergence_parameters=self.convergence_params, convergence_parameters=self.convergence_params,
transpilation_parameters=run_qk_params, transpilation_parameters=run_qk_params,
@@ -159,40 +150,40 @@ class QMatchaTeaBackend(QibotnBackend):
# Set configurations or defaults # Set configurations or defaults
self.convergence_params = ( self.convergence_params = (
convergence_params or self.qmatchatea.QCConvergenceParameters() convergence_params or qmatchatea.QCConvergenceParameters()
) )
self.ansatz = ansatz self.ansatz = ansatz
def _setup_qmatchatea_backend(self): def _setup_qmatchatea_backend(self):
"""Configure qmatchatea QCBackend object.""" """Configure qmatchatea QCBackend object."""
self.qmatchatea_device = ( qmatchatea_device = (
"cpu" if "CPU" in self.device else "gpu" if "GPU" in self.device else None "cpu" if "CPU" in self.device else "gpu" if "GPU" in self.device else None
) )
self.qmatchatea_precision = ( qmatchatea_precision = (
"C" "C"
if self.precision == "single" if self.precision == "single"
else "Z" if self.precision == "double" else "A" else "Z" if self.precision == "double" else "A"
) )
# TODO: once MPI is available for Python, integrate it here # TODO: once MPI is available for Python, integrate it here
self.qmatchatea_backend = self.qmatchatea.QCBackend( self.qmatchatea_backend = qmatchatea.QCBackend(
backend="PY", # The only alternative is Fortran, but we use Python here backend="PY", # The only alternative is Fortran, but we use Python here
precision=self.qmatchatea_precision, precision=qmatchatea_precision,
device=self.qmatchatea_device, device=qmatchatea_device,
ansatz=self.ansatz, ansatz=self.ansatz,
) )
def _qibocirc_to_qiskitcirc(self, qibo_circuit) -> QuantumCircuit: def _qibocirc_to_qiskitcirc(self, qibo_circuit) -> qiskit.QuantumCircuit:
"""Convert a Qibo Circuit into a Qiskit Circuit.""" """Convert a Qibo Circuit into a Qiskit Circuit."""
# Convert the circuit to QASM 2.0 to qiskit # Convert the circuit to QASM 2.0 to qiskit
qasm_circuit = qibo_circuit.to_qasm() qasm_circuit = qibo_circuit.to_qasm()
qiskit_circuit = QuantumCircuit.from_qasm_str(qasm_circuit) qiskit_circuit = qiskit.QuantumCircuit.from_qasm_str(qasm_circuit)
# Transpile the circuit to adapt it to the linear structure of the MPS, # Transpile the circuit to adapt it to the linear structure of the MPS,
# with the constraint of having only the gates basis_gates # with the constraint of having only the gates basis_gates
qiskit_circuit = self.qmatchatea.preprocessing.preprocess( qiskit_circuit = qmatchatea.preprocessing.preprocess(
qiskit_circuit, qiskit_circuit,
qk_params=self.qmatchatea.preprocessing.qk_transpilation_params(), qk_params=qmatchatea.preprocessing.qk_transpilation_params(),
) )
return qiskit_circuit return qiskit_circuit