feat: dynamically inheriting with quimb
This commit is contained in:
@@ -13,7 +13,7 @@ class MetaBackend:
|
|||||||
"""Meta-backend class which takes care of loading the qibotn backends."""
|
"""Meta-backend class which takes care of loading the qibotn backends."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load(platform: str, runcard: dict = None) -> QibotnBackend:
|
def load(platform: str, runcard: dict = None, **kwargs) -> QibotnBackend:
|
||||||
"""Loads the backend.
|
"""Loads the backend.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -26,7 +26,8 @@ class MetaBackend:
|
|||||||
if platform == "cutensornet": # pragma: no cover
|
if platform == "cutensornet": # pragma: no cover
|
||||||
return CuTensorNet(runcard)
|
return CuTensorNet(runcard)
|
||||||
elif platform == "quimb": # pragma: no cover
|
elif platform == "quimb": # pragma: no cover
|
||||||
return QuimbBackend()
|
quimb_backend = kwargs.get("quimb_backend", "numpy")
|
||||||
|
return QuimbBackend(quimb_backend)
|
||||||
elif platform == "qmatchatea": # pragma: no cover
|
elif platform == "qmatchatea": # pragma: no cover
|
||||||
from qibotn.backends.qmatchatea import QMatchaTeaBackend
|
from qibotn.backends.qmatchatea import QMatchaTeaBackend
|
||||||
|
|
||||||
|
|||||||
@@ -39,10 +39,12 @@ GATE_MAP = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class QuimbBackend(QibotnBackend, NumpyBackend):
|
# class QuimbBackend(QibotnBackend, NumpyBackend):
|
||||||
|
if not __name__ == "__main__":
|
||||||
|
|
||||||
def __init__(self, engine="numpy"):
|
def __init__(self, engine="numpy"):
|
||||||
super().__init__()
|
# uper().__init__()
|
||||||
|
super(self.__class__, self).__init__()
|
||||||
|
|
||||||
self.name = "qibotn"
|
self.name = "qibotn"
|
||||||
self.platform = "quimb"
|
self.platform = "quimb"
|
||||||
@@ -85,6 +87,7 @@ class QuimbBackend(QibotnBackend, NumpyBackend):
|
|||||||
contractions_optimizer: str, optional
|
contractions_optimizer: str, optional
|
||||||
The contractions_optimizer to use for the quimb tensor network simulation.
|
The contractions_optimizer to use for the quimb tensor network simulation.
|
||||||
"""
|
"""
|
||||||
|
# this is not really working because it does not change the inheritance
|
||||||
if quimb_backend == "jax":
|
if quimb_backend == "jax":
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
@@ -142,7 +145,6 @@ class QuimbBackend(QibotnBackend, NumpyBackend):
|
|||||||
- If `initial_state` is provided, it must be compatible with the MPS ansatz.
|
- If `initial_state` is provided, it must be compatible with the MPS ansatz.
|
||||||
- The `nshots` parameter enables sampling from the circuit's output distribution. If not specified, the full statevector is computed.
|
- The `nshots` parameter enables sampling from the circuit's output distribution. If not specified, the full statevector is computed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if initial_state is not None and self.ansatz == "MPS":
|
if initial_state is not None and self.ansatz == "MPS":
|
||||||
initial_state = qtn.tensor_1d.MatrixProductState.from_dense(
|
initial_state = qtn.tensor_1d.MatrixProductState.from_dense(
|
||||||
initial_state, 2
|
initial_state, 2
|
||||||
@@ -304,3 +306,32 @@ class QuimbBackend(QibotnBackend, NumpyBackend):
|
|||||||
for c in op_str[1:]:
|
for c in op_str[1:]:
|
||||||
op = op & qu.pauli(c)
|
op = op & qu.pauli(c)
|
||||||
return op
|
return op
|
||||||
|
|
||||||
|
|
||||||
|
def QuimbBackend(quimb_backend: str = "numpy") -> QibotnBackend:
|
||||||
|
bases = (QibotnBackend,)
|
||||||
|
methods = {
|
||||||
|
"__init__": __init__,
|
||||||
|
"configure_tn_simulation": configure_tn_simulation,
|
||||||
|
"setup_backend_specifics": setup_backend_specifics,
|
||||||
|
"execute_circuit": execute_circuit,
|
||||||
|
"expectation_observable_symbolic_from_state": expectation_observable_symbolic_from_state,
|
||||||
|
"_qibo_circuit_to_quimb": _qibo_circuit_to_quimb,
|
||||||
|
"_string_to_quimb_operator": _string_to_quimb_operator,
|
||||||
|
}
|
||||||
|
if quimb_backend == "numpy":
|
||||||
|
from qibo.backends import NumpyBackend
|
||||||
|
|
||||||
|
bases += (NumpyBackend,)
|
||||||
|
elif quimb_backend == "torch":
|
||||||
|
from qiboml.backends import PyTorchBackend
|
||||||
|
|
||||||
|
bases += (PyTorchBackend,)
|
||||||
|
elif quimb_backend == "jax":
|
||||||
|
from qiboml.backends import JaxBackend
|
||||||
|
|
||||||
|
bases += (JaxBackend,)
|
||||||
|
else:
|
||||||
|
raise_error(ValueError, f"Unsupported quimb backend: {quimb_backend}")
|
||||||
|
|
||||||
|
return type("QuimbBackend", bases, methods)(quimb_backend)
|
||||||
|
|||||||
Reference in New Issue
Block a user