feat: dynamically inheriting with quimb

This commit is contained in:
BrunoLiegiBastonLiegi
2025-09-23 17:57:00 +02:00
parent 2b5fca800c
commit ab10b13d9b
2 changed files with 37 additions and 5 deletions

View File

@@ -13,7 +13,7 @@ class MetaBackend:
"""Meta-backend class which takes care of loading the qibotn backends."""
@staticmethod
def load(platform: str, runcard: dict = None) -> QibotnBackend:
def load(platform: str, runcard: dict = None, **kwargs) -> QibotnBackend:
"""Loads the backend.
Args:
@@ -26,7 +26,8 @@ class MetaBackend:
if platform == "cutensornet": # pragma: no cover
return CuTensorNet(runcard)
elif platform == "quimb": # pragma: no cover
return QuimbBackend()
quimb_backend = kwargs.get("quimb_backend", "numpy")
return QuimbBackend(quimb_backend)
elif platform == "qmatchatea": # pragma: no cover
from qibotn.backends.qmatchatea import QMatchaTeaBackend

View File

@@ -39,10 +39,12 @@ GATE_MAP = {
}
class QuimbBackend(QibotnBackend, NumpyBackend):
# class QuimbBackend(QibotnBackend, NumpyBackend):
if not __name__ == "__main__":
def __init__(self, engine="numpy"):
super().__init__()
# uper().__init__()
super(self.__class__, self).__init__()
self.name = "qibotn"
self.platform = "quimb"
@@ -85,6 +87,7 @@ class QuimbBackend(QibotnBackend, NumpyBackend):
contractions_optimizer: str, optional
The contractions_optimizer to use for the quimb tensor network simulation.
"""
# this is not really working because it does not change the inheritance
if quimb_backend == "jax":
import jax.numpy as jnp
@@ -142,7 +145,6 @@ class QuimbBackend(QibotnBackend, NumpyBackend):
- If `initial_state` is provided, it must be compatible with the MPS ansatz.
- The `nshots` parameter enables sampling from the circuit's output distribution. If not specified, the full statevector is computed.
"""
if initial_state is not None and self.ansatz == "MPS":
initial_state = qtn.tensor_1d.MatrixProductState.from_dense(
initial_state, 2
@@ -304,3 +306,32 @@ class QuimbBackend(QibotnBackend, NumpyBackend):
for c in op_str[1:]:
op = op & qu.pauli(c)
return op
def QuimbBackend(quimb_backend: str = "numpy") -> QibotnBackend:
bases = (QibotnBackend,)
methods = {
"__init__": __init__,
"configure_tn_simulation": configure_tn_simulation,
"setup_backend_specifics": setup_backend_specifics,
"execute_circuit": execute_circuit,
"expectation_observable_symbolic_from_state": expectation_observable_symbolic_from_state,
"_qibo_circuit_to_quimb": _qibo_circuit_to_quimb,
"_string_to_quimb_operator": _string_to_quimb_operator,
}
if quimb_backend == "numpy":
from qibo.backends import NumpyBackend
bases += (NumpyBackend,)
elif quimb_backend == "torch":
from qiboml.backends import PyTorchBackend
bases += (PyTorchBackend,)
elif quimb_backend == "jax":
from qiboml.backends import JaxBackend
bases += (JaxBackend,)
else:
raise_error(ValueError, f"Unsupported quimb backend: {quimb_backend}")
return type("QuimbBackend", bases, methods)(quimb_backend)