更新脚本和后端
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled

This commit is contained in:
2026-05-09 18:36:23 +08:00
parent 7cebbb0820
commit ff96e36cfc
2 changed files with 12 additions and 1 deletions

View File

@@ -65,12 +65,18 @@ def main():
parser.add_argument("--seed", type=int, default=42) parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--cut-ratio", type=float, default=1e-12) parser.add_argument("--cut-ratio", type=float, default=1e-12)
parser.add_argument("--svd-control", default="V") parser.add_argument("--svd-control", default="V")
parser.add_argument("--tensor-module", choices=("numpy", "torch"), default="numpy")
parser.add_argument("--torch-threads", type=int)
parser.add_argument("--exact", action="store_true") parser.add_argument("--exact", action="store_true")
parser.add_argument("--exact-max-qubits", type=int, default=24) parser.add_argument("--exact-max-qubits", type=int, default=24)
parser.add_argument("--preprocess", action="store_true") parser.add_argument("--preprocess", action="store_true")
args = parser.parse_args() args = parser.parse_args()
logging.getLogger("qibo.config").setLevel(logging.ERROR) logging.getLogger("qibo.config").setLevel(logging.ERROR)
logging.getLogger("qtealeaves").setLevel(logging.ERROR) logging.getLogger("qtealeaves").setLevel(logging.ERROR)
if args.torch_threads is not None:
import torch
torch.set_num_threads(args.torch_threads)
circuit = build_circuit(args.nqubits, args.nlayers, args.seed) circuit = build_circuit(args.nqubits, args.nlayers, args.seed)
observable = build_observable(args.nqubits) observable = build_observable(args.nqubits)
@@ -84,7 +90,8 @@ def main():
print( print(
f"nqubits={args.nqubits} nlayers={args.nlayers} " f"nqubits={args.nqubits} nlayers={args.nlayers} "
f"seed={args.seed} preprocess={args.preprocess}" f"seed={args.seed} preprocess={args.preprocess} "
f"tensor_module={args.tensor_module}"
) )
if exact is not None: if exact is not None:
print(f"exact={exact:.16e}") print(f"exact={exact:.16e}")
@@ -97,6 +104,7 @@ def main():
max_bond_dimension=bond, max_bond_dimension=bond,
cut_ratio=args.cut_ratio, cut_ratio=args.cut_ratio,
svd_control=args.svd_control, svd_control=args.svd_control,
tensor_module=args.tensor_module,
) )
start = time.perf_counter() start = time.perf_counter()
value = float( value = float(

View File

@@ -39,6 +39,7 @@ class QMatchaTeaBackend(QibotnBackend, NumpyBackend):
trunc_tracking_mode: str = "C", trunc_tracking_mode: str = "C",
svd_control: str = "A", svd_control: str = "A",
ini_bond_dimension: int = 1, ini_bond_dimension: int = 1,
tensor_module: str = "numpy",
): ):
"""Configure TN simulation given Quantum Matcha Tea interface. """Configure TN simulation given Quantum Matcha Tea interface.
@@ -76,6 +77,7 @@ class QMatchaTeaBackend(QibotnBackend, NumpyBackend):
ini_bond_dimension=ini_bond_dimension, ini_bond_dimension=ini_bond_dimension,
) )
self.ansatz = ansatz self.ansatz = ansatz
self.tensor_module = tensor_module
if hasattr(self, "qmatchatea_backend"): if hasattr(self, "qmatchatea_backend"):
self._setup_backend_specifics() self._setup_backend_specifics()
@@ -96,6 +98,7 @@ class QMatchaTeaBackend(QibotnBackend, NumpyBackend):
precision=qmatchatea_precision, precision=qmatchatea_precision,
device=qmatchatea_device, device=qmatchatea_device,
ansatz=self.ansatz, ansatz=self.ansatz,
tensor_module=self.tensor_module,
) )
def execute_circuit( def execute_circuit(