From ab10b13d9b7e4691c7b4d58a88c555e61d57ccc7 Mon Sep 17 00:00:00 2001 From: BrunoLiegiBastonLiegi Date: Tue, 23 Sep 2025 17:57:00 +0200 Subject: [PATCH] feat: dynamically inheriting with quimb --- src/qibotn/backends/__init__.py | 5 +++-- src/qibotn/backends/quimb.py | 37 ++++++++++++++++++++++++++++++--- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/src/qibotn/backends/__init__.py b/src/qibotn/backends/__init__.py index 416a384..4a5dbdd 100644 --- a/src/qibotn/backends/__init__.py +++ b/src/qibotn/backends/__init__.py @@ -13,7 +13,7 @@ class MetaBackend: """Meta-backend class which takes care of loading the qibotn backends.""" @staticmethod - def load(platform: str, runcard: dict = None) -> QibotnBackend: + def load(platform: str, runcard: dict = None, **kwargs) -> QibotnBackend: """Loads the backend. Args: @@ -26,7 +26,8 @@ class MetaBackend: if platform == "cutensornet": # pragma: no cover return CuTensorNet(runcard) elif platform == "quimb": # pragma: no cover - return QuimbBackend() + quimb_backend = kwargs.get("quimb_backend", "numpy") + return QuimbBackend(quimb_backend) elif platform == "qmatchatea": # pragma: no cover from qibotn.backends.qmatchatea import QMatchaTeaBackend diff --git a/src/qibotn/backends/quimb.py b/src/qibotn/backends/quimb.py index 6e4629b..658ca7e 100644 --- a/src/qibotn/backends/quimb.py +++ b/src/qibotn/backends/quimb.py @@ -39,10 +39,12 @@ GATE_MAP = { } -class QuimbBackend(QibotnBackend, NumpyBackend): +# class QuimbBackend(QibotnBackend, NumpyBackend): +if not __name__ == "__main__": def __init__(self, engine="numpy"): - super().__init__() + # uper().__init__() + super(self.__class__, self).__init__() self.name = "qibotn" self.platform = "quimb" @@ -85,6 +87,7 @@ class QuimbBackend(QibotnBackend, NumpyBackend): contractions_optimizer: str, optional The contractions_optimizer to use for the quimb tensor network simulation. """ + # this is not really working because it does not change the inheritance if quimb_backend == "jax": import jax.numpy as jnp @@ -142,7 +145,6 @@ class QuimbBackend(QibotnBackend, NumpyBackend): - If `initial_state` is provided, it must be compatible with the MPS ansatz. - The `nshots` parameter enables sampling from the circuit's output distribution. If not specified, the full statevector is computed. """ - if initial_state is not None and self.ansatz == "MPS": initial_state = qtn.tensor_1d.MatrixProductState.from_dense( initial_state, 2 @@ -304,3 +306,32 @@ class QuimbBackend(QibotnBackend, NumpyBackend): for c in op_str[1:]: op = op & qu.pauli(c) return op + + +def QuimbBackend(quimb_backend: str = "numpy") -> QibotnBackend: + bases = (QibotnBackend,) + methods = { + "__init__": __init__, + "configure_tn_simulation": configure_tn_simulation, + "setup_backend_specifics": setup_backend_specifics, + "execute_circuit": execute_circuit, + "expectation_observable_symbolic_from_state": expectation_observable_symbolic_from_state, + "_qibo_circuit_to_quimb": _qibo_circuit_to_quimb, + "_string_to_quimb_operator": _string_to_quimb_operator, + } + if quimb_backend == "numpy": + from qibo.backends import NumpyBackend + + bases += (NumpyBackend,) + elif quimb_backend == "torch": + from qiboml.backends import PyTorchBackend + + bases += (PyTorchBackend,) + elif quimb_backend == "jax": + from qiboml.backends import JaxBackend + + bases += (JaxBackend,) + else: + raise_error(ValueError, f"Unsupported quimb backend: {quimb_backend}") + + return type("QuimbBackend", bases, methods)(quimb_backend)