Changed contract() input to interleaved format
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
import cupy as cp
|
||||
import numpy as np
|
||||
|
||||
EINSUM_SYMBOLS_BASE = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
|
||||
|
||||
class QiboCircuitToEinsum:
|
||||
def __init__(self, circuit, dtype="complex128"):
|
||||
@@ -54,21 +52,13 @@ class QiboCircuitToEinsum:
|
||||
operands = input_operands + gate_operands
|
||||
mode_labels += gate_mode_labels
|
||||
|
||||
expression = self._convert_mode_labels_to_expression(
|
||||
mode_labels, qubits_frontier
|
||||
)
|
||||
out_list = []
|
||||
for key in qubits_frontier:
|
||||
out_list.append(qubits_frontier[key])
|
||||
|
||||
return expression, operands
|
||||
|
||||
def _get_symbol(self, i):
|
||||
"""
|
||||
Return a Unicode as label for index.
|
||||
|
||||
.. note:: This function is adopted from `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/_modules/opt_einsum/parser.html#get_symbol>`_
|
||||
"""
|
||||
if i < 52:
|
||||
return EINSUM_SYMBOLS_BASE[i]
|
||||
return chr(i + 140)
|
||||
operand_exp_interleave = [x for y in zip(operands, mode_labels) for x in y]
|
||||
operand_exp_interleave.append(out_list)
|
||||
return operand_exp_interleave
|
||||
|
||||
def _init_mode_labels_from_qubits(self, qubits):
|
||||
frontier_dict = {}
|
||||
@@ -104,17 +94,3 @@ class QiboCircuitToEinsum:
|
||||
next_frontier += 1
|
||||
mode_labels.append(output_mode_labels + input_mode_labels)
|
||||
return mode_labels, operands
|
||||
|
||||
def _convert_mode_labels_to_expression(self, input_mode_labels, output_mode_labels):
|
||||
out_list = []
|
||||
for key in output_mode_labels:
|
||||
out_list.append(output_mode_labels[key])
|
||||
|
||||
input_symbols = [
|
||||
"".join(map(self._get_symbol, idx)) for idx in input_mode_labels
|
||||
]
|
||||
expression = (
|
||||
",".join(input_symbols) + "->" + "".join(map(self._get_symbol, out_list))
|
||||
)
|
||||
|
||||
return expression
|
||||
|
||||
@@ -47,7 +47,6 @@ def main(args: argparse.Namespace):
|
||||
nqubits = args.nqubits
|
||||
circuit_name = args.circuit
|
||||
datatype = args.precision
|
||||
# Create qibo quibit
|
||||
|
||||
if circuit_name in ("qft", "QFT"):
|
||||
circuit = QFT(nqubits)
|
||||
@@ -55,12 +54,10 @@ def main(args: argparse.Namespace):
|
||||
raise NotImplementedError(f"Cannot find circuit {circuit_name}.")
|
||||
|
||||
myconvertor = QiboCircuitToEinsum(circuit, dtype=datatype)
|
||||
expression, operands = myconvertor.state_vector()
|
||||
operands_expression = myconvertor.state_vector()
|
||||
|
||||
result_qibo = run_bench(circuit, "Qibo")
|
||||
sv_cutn = run_bench(
|
||||
lambda: contract(expression, *operands), "cuQuantum cuTensorNet"
|
||||
)
|
||||
sv_cutn = run_bench(lambda: contract(*operands_expression), "cuQuantum cuTensorNet")
|
||||
|
||||
# print(f"is sv in agreement?", cp.allclose(sv_cutn.flatten(), result_qibo.state(numpy=True)))
|
||||
assert cp.allclose(sv_cutn.flatten(), result_qibo.state(numpy=True))
|
||||
|
||||
Reference in New Issue
Block a user