Update dense vector tn nccl
This commit is contained in:
@@ -74,7 +74,7 @@ def dense_vector_tn_MPI(qibo_circ, datatype, n_samples=8):
|
|||||||
device_id = rank % getDeviceCount()
|
device_id = rank % getDeviceCount()
|
||||||
cp.cuda.Device(device_id).use()
|
cp.cuda.Device(device_id).use()
|
||||||
mempool = cp.get_default_memory_pool()
|
mempool = cp.get_default_memory_pool()
|
||||||
|
|
||||||
# Perform circuit conversion
|
# Perform circuit conversion
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
myconvertor = QiboCircuitToEinsum(qibo_circ, dtype=datatype)
|
myconvertor = QiboCircuitToEinsum(qibo_circ, dtype=datatype)
|
||||||
@@ -127,7 +127,7 @@ def dense_vector_tn_MPI(qibo_circ, datatype, n_samples=8):
|
|||||||
|
|
||||||
del network
|
del network
|
||||||
mempool.free_all_blocks()
|
mempool.free_all_blocks()
|
||||||
|
|
||||||
return result, rank
|
return result, rank
|
||||||
|
|
||||||
|
|
||||||
@@ -163,6 +163,7 @@ def dense_vector_tn_nccl(qibo_circ, datatype, n_samples=8):
|
|||||||
device_id = rank % getDeviceCount()
|
device_id = rank % getDeviceCount()
|
||||||
|
|
||||||
cp.cuda.Device(device_id).use()
|
cp.cuda.Device(device_id).use()
|
||||||
|
mempool = cp.get_default_memory_pool()
|
||||||
|
|
||||||
# Set up the NCCL communicator.
|
# Set up the NCCL communicator.
|
||||||
nccl_id = nccl.get_unique_id() if rank == root else None
|
nccl_id = nccl.get_unique_id() if rank == root else None
|
||||||
@@ -172,6 +173,14 @@ def dense_vector_tn_nccl(qibo_circ, datatype, n_samples=8):
|
|||||||
# Perform circuit conversion
|
# Perform circuit conversion
|
||||||
myconvertor = QiboCircuitToEinsum(qibo_circ, dtype=datatype)
|
myconvertor = QiboCircuitToEinsum(qibo_circ, dtype=datatype)
|
||||||
operands = myconvertor.state_vector_operands()
|
operands = myconvertor.state_vector_operands()
|
||||||
|
# Perform circuit conversion
|
||||||
|
if rank == 0:
|
||||||
|
myconvertor = QiboCircuitToEinsum(qibo_circ, dtype=datatype)
|
||||||
|
operands = myconvertor.state_vector_operands()
|
||||||
|
else:
|
||||||
|
operands = None
|
||||||
|
|
||||||
|
operands = comm_mpi.bcast(operands, root)
|
||||||
|
|
||||||
network = Network(*operands)
|
network = Network(*operands)
|
||||||
|
|
||||||
@@ -221,6 +230,9 @@ def dense_vector_tn_nccl(qibo_circ, datatype, n_samples=8):
|
|||||||
stream_ptr,
|
stream_ptr,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
del network
|
||||||
|
mempool.free_all_blocks()
|
||||||
|
|
||||||
return result, rank
|
return result, rank
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user