更新脚本和后端
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
This commit is contained in:
@@ -65,12 +65,18 @@ def main():
|
|||||||
parser.add_argument("--seed", type=int, default=42)
|
parser.add_argument("--seed", type=int, default=42)
|
||||||
parser.add_argument("--cut-ratio", type=float, default=1e-12)
|
parser.add_argument("--cut-ratio", type=float, default=1e-12)
|
||||||
parser.add_argument("--svd-control", default="V")
|
parser.add_argument("--svd-control", default="V")
|
||||||
|
parser.add_argument("--tensor-module", choices=("numpy", "torch"), default="numpy")
|
||||||
|
parser.add_argument("--torch-threads", type=int)
|
||||||
parser.add_argument("--exact", action="store_true")
|
parser.add_argument("--exact", action="store_true")
|
||||||
parser.add_argument("--exact-max-qubits", type=int, default=24)
|
parser.add_argument("--exact-max-qubits", type=int, default=24)
|
||||||
parser.add_argument("--preprocess", action="store_true")
|
parser.add_argument("--preprocess", action="store_true")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
logging.getLogger("qibo.config").setLevel(logging.ERROR)
|
logging.getLogger("qibo.config").setLevel(logging.ERROR)
|
||||||
logging.getLogger("qtealeaves").setLevel(logging.ERROR)
|
logging.getLogger("qtealeaves").setLevel(logging.ERROR)
|
||||||
|
if args.torch_threads is not None:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
torch.set_num_threads(args.torch_threads)
|
||||||
|
|
||||||
circuit = build_circuit(args.nqubits, args.nlayers, args.seed)
|
circuit = build_circuit(args.nqubits, args.nlayers, args.seed)
|
||||||
observable = build_observable(args.nqubits)
|
observable = build_observable(args.nqubits)
|
||||||
@@ -84,7 +90,8 @@ def main():
|
|||||||
|
|
||||||
print(
|
print(
|
||||||
f"nqubits={args.nqubits} nlayers={args.nlayers} "
|
f"nqubits={args.nqubits} nlayers={args.nlayers} "
|
||||||
f"seed={args.seed} preprocess={args.preprocess}"
|
f"seed={args.seed} preprocess={args.preprocess} "
|
||||||
|
f"tensor_module={args.tensor_module}"
|
||||||
)
|
)
|
||||||
if exact is not None:
|
if exact is not None:
|
||||||
print(f"exact={exact:.16e}")
|
print(f"exact={exact:.16e}")
|
||||||
@@ -97,6 +104,7 @@ def main():
|
|||||||
max_bond_dimension=bond,
|
max_bond_dimension=bond,
|
||||||
cut_ratio=args.cut_ratio,
|
cut_ratio=args.cut_ratio,
|
||||||
svd_control=args.svd_control,
|
svd_control=args.svd_control,
|
||||||
|
tensor_module=args.tensor_module,
|
||||||
)
|
)
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
value = float(
|
value = float(
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ class QMatchaTeaBackend(QibotnBackend, NumpyBackend):
|
|||||||
trunc_tracking_mode: str = "C",
|
trunc_tracking_mode: str = "C",
|
||||||
svd_control: str = "A",
|
svd_control: str = "A",
|
||||||
ini_bond_dimension: int = 1,
|
ini_bond_dimension: int = 1,
|
||||||
|
tensor_module: str = "numpy",
|
||||||
):
|
):
|
||||||
"""Configure TN simulation given Quantum Matcha Tea interface.
|
"""Configure TN simulation given Quantum Matcha Tea interface.
|
||||||
|
|
||||||
@@ -76,6 +77,7 @@ class QMatchaTeaBackend(QibotnBackend, NumpyBackend):
|
|||||||
ini_bond_dimension=ini_bond_dimension,
|
ini_bond_dimension=ini_bond_dimension,
|
||||||
)
|
)
|
||||||
self.ansatz = ansatz
|
self.ansatz = ansatz
|
||||||
|
self.tensor_module = tensor_module
|
||||||
if hasattr(self, "qmatchatea_backend"):
|
if hasattr(self, "qmatchatea_backend"):
|
||||||
self._setup_backend_specifics()
|
self._setup_backend_specifics()
|
||||||
|
|
||||||
@@ -96,6 +98,7 @@ class QMatchaTeaBackend(QibotnBackend, NumpyBackend):
|
|||||||
precision=qmatchatea_precision,
|
precision=qmatchatea_precision,
|
||||||
device=qmatchatea_device,
|
device=qmatchatea_device,
|
||||||
ansatz=self.ansatz,
|
ansatz=self.ansatz,
|
||||||
|
tensor_module=self.tensor_module,
|
||||||
)
|
)
|
||||||
|
|
||||||
def execute_circuit(
|
def execute_circuit(
|
||||||
|
|||||||
Reference in New Issue
Block a user