clean up reduce_result: minor refactor
This commit is contained in:
@@ -31,7 +31,6 @@ def build_observable(circuit_nqubit):
|
||||
for i in range(circuit_nqubit):
|
||||
hamiltonian_form += 0.5 * X(i % circuit_nqubit) * Z((i + 1) % circuit_nqubit)
|
||||
|
||||
print("Default hamiltonian: ", hamiltonian_form)
|
||||
hamiltonian = hamiltonians.SymbolicHamiltonian(form=hamiltonian_form)
|
||||
return hamiltonian
|
||||
|
||||
@@ -195,11 +194,6 @@ def compute_optimal_path(network, n_samples, size, comm):
|
||||
return comm.bcast(info, sender)
|
||||
|
||||
|
||||
def compute_contraction(network, slices):
|
||||
"""Perform tensor contraction."""
|
||||
return network.contract(slices=slices)
|
||||
|
||||
|
||||
def compute_slices(info, rank, size):
|
||||
"""Determine the slice range each process should compute."""
|
||||
num_slices = info.num_slices
|
||||
@@ -215,6 +209,7 @@ def reduce_result(result, comm, method="MPI", root=0):
|
||||
"""Reduce results across processes."""
|
||||
if method == "MPI":
|
||||
return comm.reduce(sendobj=result, op=MPI.SUM, root=root)
|
||||
|
||||
elif method == "NCCL":
|
||||
stream_ptr = cp.cuda.get_current_stream().ptr
|
||||
if result.dtype == cp.complex128:
|
||||
@@ -236,6 +231,8 @@ def reduce_result(result, comm, method="MPI", root=0):
|
||||
stream_ptr,
|
||||
)
|
||||
return result
|
||||
else:
|
||||
raise ValueError(f"Unknown reduce method: {method}")
|
||||
|
||||
|
||||
def dense_vector_tn_MPI(qibo_circ, datatype, n_samples=8):
|
||||
@@ -265,7 +262,7 @@ def dense_vector_tn_MPI(qibo_circ, datatype, n_samples=8):
|
||||
optimize={"path": info.path, "slicing": info.slices}
|
||||
)
|
||||
slices = compute_slices(info, rank, size)
|
||||
result = compute_contraction(network, slices)
|
||||
result = network.contract(slices=slices)
|
||||
return reduce_result(result, comm, method="MPI"), rank
|
||||
|
||||
|
||||
@@ -297,7 +294,7 @@ def dense_vector_tn_nccl(qibo_circ, datatype, n_samples=8):
|
||||
optimize={"path": info.path, "slicing": info.slices}
|
||||
)
|
||||
slices = compute_slices(info, rank, size)
|
||||
result = compute_contraction(network, slices)
|
||||
result = network.contract(slices=slices)
|
||||
return reduce_result(result, comm_nccl, method="NCCL"), rank
|
||||
|
||||
|
||||
@@ -375,7 +372,7 @@ def expectation_tn_nccl(qibo_circ, datatype, observable, n_samples=8):
|
||||
slices = compute_slices(info, rank, size)
|
||||
|
||||
# Contract the group of slices the process is responsible for.
|
||||
result = compute_contraction(network, slices)
|
||||
result = network.contract(slices=slices)
|
||||
|
||||
# Sum the partial contribution from each process on root.
|
||||
result = reduce_result(result, comm_nccl, method="NCCL", root=0)
|
||||
@@ -444,7 +441,7 @@ def expectation_tn_MPI(qibo_circ, datatype, observable, n_samples=8):
|
||||
slices = compute_slices(info, rank, size)
|
||||
|
||||
# Perform contraction
|
||||
result = compute_contraction(network, slices)
|
||||
result = network.contract(slices=slices)
|
||||
|
||||
# Sum the partial contribution from each process on root.
|
||||
result = reduce_result(result, comm, method="MPI", root=0)
|
||||
|
||||
Reference in New Issue
Block a user