From 05cd1001caa9045cd93e0db7e4aa7b1fd852c48d Mon Sep 17 00:00:00 2001 From: tankya2 Date: Fri, 4 Oct 2024 14:38:39 +0800 Subject: [PATCH] Update dense vector tn nccl --- src/qibotn/eval.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/qibotn/eval.py b/src/qibotn/eval.py index 8c4bdff..2d7c441 100644 --- a/src/qibotn/eval.py +++ b/src/qibotn/eval.py @@ -74,7 +74,7 @@ def dense_vector_tn_MPI(qibo_circ, datatype, n_samples=8): device_id = rank % getDeviceCount() cp.cuda.Device(device_id).use() mempool = cp.get_default_memory_pool() - + # Perform circuit conversion if rank == 0: myconvertor = QiboCircuitToEinsum(qibo_circ, dtype=datatype) @@ -127,7 +127,7 @@ def dense_vector_tn_MPI(qibo_circ, datatype, n_samples=8): del network mempool.free_all_blocks() - + return result, rank @@ -163,6 +163,7 @@ def dense_vector_tn_nccl(qibo_circ, datatype, n_samples=8): device_id = rank % getDeviceCount() cp.cuda.Device(device_id).use() + mempool = cp.get_default_memory_pool() # Set up the NCCL communicator. nccl_id = nccl.get_unique_id() if rank == root else None @@ -172,6 +173,14 @@ def dense_vector_tn_nccl(qibo_circ, datatype, n_samples=8): # Perform circuit conversion myconvertor = QiboCircuitToEinsum(qibo_circ, dtype=datatype) operands = myconvertor.state_vector_operands() + # Perform circuit conversion + if rank == 0: + myconvertor = QiboCircuitToEinsum(qibo_circ, dtype=datatype) + operands = myconvertor.state_vector_operands() + else: + operands = None + + operands = comm_mpi.bcast(operands, root) network = Network(*operands) @@ -221,6 +230,9 @@ def dense_vector_tn_nccl(qibo_circ, datatype, n_samples=8): stream_ptr, ) + del network + mempool.free_all_blocks() + return result, rank