Update dense vector tn nccl

This commit is contained in:
tankya2
2024-10-04 14:38:39 +08:00
parent 4cc59564cf
commit 05cd1001ca

View File

@@ -74,7 +74,7 @@ def dense_vector_tn_MPI(qibo_circ, datatype, n_samples=8):
device_id = rank % getDeviceCount() device_id = rank % getDeviceCount()
cp.cuda.Device(device_id).use() cp.cuda.Device(device_id).use()
mempool = cp.get_default_memory_pool() mempool = cp.get_default_memory_pool()
# Perform circuit conversion # Perform circuit conversion
if rank == 0: if rank == 0:
myconvertor = QiboCircuitToEinsum(qibo_circ, dtype=datatype) myconvertor = QiboCircuitToEinsum(qibo_circ, dtype=datatype)
@@ -127,7 +127,7 @@ def dense_vector_tn_MPI(qibo_circ, datatype, n_samples=8):
del network del network
mempool.free_all_blocks() mempool.free_all_blocks()
return result, rank return result, rank
@@ -163,6 +163,7 @@ def dense_vector_tn_nccl(qibo_circ, datatype, n_samples=8):
device_id = rank % getDeviceCount() device_id = rank % getDeviceCount()
cp.cuda.Device(device_id).use() cp.cuda.Device(device_id).use()
mempool = cp.get_default_memory_pool()
# Set up the NCCL communicator. # Set up the NCCL communicator.
nccl_id = nccl.get_unique_id() if rank == root else None nccl_id = nccl.get_unique_id() if rank == root else None
@@ -172,6 +173,14 @@ def dense_vector_tn_nccl(qibo_circ, datatype, n_samples=8):
# Perform circuit conversion # Perform circuit conversion
myconvertor = QiboCircuitToEinsum(qibo_circ, dtype=datatype) myconvertor = QiboCircuitToEinsum(qibo_circ, dtype=datatype)
operands = myconvertor.state_vector_operands() operands = myconvertor.state_vector_operands()
# Perform circuit conversion
if rank == 0:
myconvertor = QiboCircuitToEinsum(qibo_circ, dtype=datatype)
operands = myconvertor.state_vector_operands()
else:
operands = None
operands = comm_mpi.bcast(operands, root)
network = Network(*operands) network = Network(*operands)
@@ -221,6 +230,9 @@ def dense_vector_tn_nccl(qibo_circ, datatype, n_samples=8):
stream_ptr, stream_ptr,
) )
del network
mempool.free_all_blocks()
return result, rank return result, rank