Changed contract() input to interleaved format
This commit is contained in:
@@ -1,8 +1,6 @@
|
|||||||
import cupy as cp
|
import cupy as cp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
EINSUM_SYMBOLS_BASE = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
|
||||||
|
|
||||||
|
|
||||||
class QiboCircuitToEinsum:
|
class QiboCircuitToEinsum:
|
||||||
def __init__(self, circuit, dtype="complex128"):
|
def __init__(self, circuit, dtype="complex128"):
|
||||||
@@ -54,21 +52,13 @@ class QiboCircuitToEinsum:
|
|||||||
operands = input_operands + gate_operands
|
operands = input_operands + gate_operands
|
||||||
mode_labels += gate_mode_labels
|
mode_labels += gate_mode_labels
|
||||||
|
|
||||||
expression = self._convert_mode_labels_to_expression(
|
out_list = []
|
||||||
mode_labels, qubits_frontier
|
for key in qubits_frontier:
|
||||||
)
|
out_list.append(qubits_frontier[key])
|
||||||
|
|
||||||
return expression, operands
|
operand_exp_interleave = [x for y in zip(operands, mode_labels) for x in y]
|
||||||
|
operand_exp_interleave.append(out_list)
|
||||||
def _get_symbol(self, i):
|
return operand_exp_interleave
|
||||||
"""
|
|
||||||
Return a Unicode as label for index.
|
|
||||||
|
|
||||||
.. note:: This function is adopted from `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/_modules/opt_einsum/parser.html#get_symbol>`_
|
|
||||||
"""
|
|
||||||
if i < 52:
|
|
||||||
return EINSUM_SYMBOLS_BASE[i]
|
|
||||||
return chr(i + 140)
|
|
||||||
|
|
||||||
def _init_mode_labels_from_qubits(self, qubits):
|
def _init_mode_labels_from_qubits(self, qubits):
|
||||||
frontier_dict = {}
|
frontier_dict = {}
|
||||||
@@ -104,17 +94,3 @@ class QiboCircuitToEinsum:
|
|||||||
next_frontier += 1
|
next_frontier += 1
|
||||||
mode_labels.append(output_mode_labels + input_mode_labels)
|
mode_labels.append(output_mode_labels + input_mode_labels)
|
||||||
return mode_labels, operands
|
return mode_labels, operands
|
||||||
|
|
||||||
def _convert_mode_labels_to_expression(self, input_mode_labels, output_mode_labels):
|
|
||||||
out_list = []
|
|
||||||
for key in output_mode_labels:
|
|
||||||
out_list.append(output_mode_labels[key])
|
|
||||||
|
|
||||||
input_symbols = [
|
|
||||||
"".join(map(self._get_symbol, idx)) for idx in input_mode_labels
|
|
||||||
]
|
|
||||||
expression = (
|
|
||||||
",".join(input_symbols) + "->" + "".join(map(self._get_symbol, out_list))
|
|
||||||
)
|
|
||||||
|
|
||||||
return expression
|
|
||||||
|
|||||||
@@ -47,7 +47,6 @@ def main(args: argparse.Namespace):
|
|||||||
nqubits = args.nqubits
|
nqubits = args.nqubits
|
||||||
circuit_name = args.circuit
|
circuit_name = args.circuit
|
||||||
datatype = args.precision
|
datatype = args.precision
|
||||||
# Create qibo quibit
|
|
||||||
|
|
||||||
if circuit_name in ("qft", "QFT"):
|
if circuit_name in ("qft", "QFT"):
|
||||||
circuit = QFT(nqubits)
|
circuit = QFT(nqubits)
|
||||||
@@ -55,12 +54,10 @@ def main(args: argparse.Namespace):
|
|||||||
raise NotImplementedError(f"Cannot find circuit {circuit_name}.")
|
raise NotImplementedError(f"Cannot find circuit {circuit_name}.")
|
||||||
|
|
||||||
myconvertor = QiboCircuitToEinsum(circuit, dtype=datatype)
|
myconvertor = QiboCircuitToEinsum(circuit, dtype=datatype)
|
||||||
expression, operands = myconvertor.state_vector()
|
operands_expression = myconvertor.state_vector()
|
||||||
|
|
||||||
result_qibo = run_bench(circuit, "Qibo")
|
result_qibo = run_bench(circuit, "Qibo")
|
||||||
sv_cutn = run_bench(
|
sv_cutn = run_bench(lambda: contract(*operands_expression), "cuQuantum cuTensorNet")
|
||||||
lambda: contract(expression, *operands), "cuQuantum cuTensorNet"
|
|
||||||
)
|
|
||||||
|
|
||||||
# print(f"is sv in agreement?", cp.allclose(sv_cutn.flatten(), result_qibo.state(numpy=True)))
|
# print(f"is sv in agreement?", cp.allclose(sv_cutn.flatten(), result_qibo.state(numpy=True)))
|
||||||
assert cp.allclose(sv_cutn.flatten(), result_qibo.state(numpy=True))
|
assert cp.allclose(sv_cutn.flatten(), result_qibo.state(numpy=True))
|
||||||
|
|||||||
Reference in New Issue
Block a user