feat: dynamically inheriting with quimb
This commit is contained in:
@@ -13,7 +13,7 @@ class MetaBackend:
|
||||
"""Meta-backend class which takes care of loading the qibotn backends."""
|
||||
|
||||
@staticmethod
|
||||
def load(platform: str, runcard: dict = None) -> QibotnBackend:
|
||||
def load(platform: str, runcard: dict = None, **kwargs) -> QibotnBackend:
|
||||
"""Loads the backend.
|
||||
|
||||
Args:
|
||||
@@ -26,7 +26,8 @@ class MetaBackend:
|
||||
if platform == "cutensornet": # pragma: no cover
|
||||
return CuTensorNet(runcard)
|
||||
elif platform == "quimb": # pragma: no cover
|
||||
return QuimbBackend()
|
||||
quimb_backend = kwargs.get("quimb_backend", "numpy")
|
||||
return QuimbBackend(quimb_backend)
|
||||
elif platform == "qmatchatea": # pragma: no cover
|
||||
from qibotn.backends.qmatchatea import QMatchaTeaBackend
|
||||
|
||||
|
||||
@@ -39,10 +39,12 @@ GATE_MAP = {
|
||||
}
|
||||
|
||||
|
||||
class QuimbBackend(QibotnBackend, NumpyBackend):
|
||||
# class QuimbBackend(QibotnBackend, NumpyBackend):
|
||||
if not __name__ == "__main__":
|
||||
|
||||
def __init__(self, engine="numpy"):
|
||||
super().__init__()
|
||||
# uper().__init__()
|
||||
super(self.__class__, self).__init__()
|
||||
|
||||
self.name = "qibotn"
|
||||
self.platform = "quimb"
|
||||
@@ -85,6 +87,7 @@ class QuimbBackend(QibotnBackend, NumpyBackend):
|
||||
contractions_optimizer: str, optional
|
||||
The contractions_optimizer to use for the quimb tensor network simulation.
|
||||
"""
|
||||
# this is not really working because it does not change the inheritance
|
||||
if quimb_backend == "jax":
|
||||
import jax.numpy as jnp
|
||||
|
||||
@@ -142,7 +145,6 @@ class QuimbBackend(QibotnBackend, NumpyBackend):
|
||||
- If `initial_state` is provided, it must be compatible with the MPS ansatz.
|
||||
- The `nshots` parameter enables sampling from the circuit's output distribution. If not specified, the full statevector is computed.
|
||||
"""
|
||||
|
||||
if initial_state is not None and self.ansatz == "MPS":
|
||||
initial_state = qtn.tensor_1d.MatrixProductState.from_dense(
|
||||
initial_state, 2
|
||||
@@ -304,3 +306,32 @@ class QuimbBackend(QibotnBackend, NumpyBackend):
|
||||
for c in op_str[1:]:
|
||||
op = op & qu.pauli(c)
|
||||
return op
|
||||
|
||||
|
||||
def QuimbBackend(quimb_backend: str = "numpy") -> QibotnBackend:
|
||||
bases = (QibotnBackend,)
|
||||
methods = {
|
||||
"__init__": __init__,
|
||||
"configure_tn_simulation": configure_tn_simulation,
|
||||
"setup_backend_specifics": setup_backend_specifics,
|
||||
"execute_circuit": execute_circuit,
|
||||
"expectation_observable_symbolic_from_state": expectation_observable_symbolic_from_state,
|
||||
"_qibo_circuit_to_quimb": _qibo_circuit_to_quimb,
|
||||
"_string_to_quimb_operator": _string_to_quimb_operator,
|
||||
}
|
||||
if quimb_backend == "numpy":
|
||||
from qibo.backends import NumpyBackend
|
||||
|
||||
bases += (NumpyBackend,)
|
||||
elif quimb_backend == "torch":
|
||||
from qiboml.backends import PyTorchBackend
|
||||
|
||||
bases += (PyTorchBackend,)
|
||||
elif quimb_backend == "jax":
|
||||
from qiboml.backends import JaxBackend
|
||||
|
||||
bases += (JaxBackend,)
|
||||
else:
|
||||
raise_error(ValueError, f"Unsupported quimb backend: {quimb_backend}")
|
||||
|
||||
return type("QuimbBackend", bases, methods)(quimb_backend)
|
||||
|
||||
Reference in New Issue
Block a user