From 39dad0fd88495bd3db007e9f5fd358b9c4f9e660 Mon Sep 17 00:00:00 2001 From: tankya2 Date: Tue, 26 Aug 2025 17:10:39 +0800 Subject: [PATCH] Fix bug on NCCL reduce. Need to change size to 2x for complex128 dtype --- src/qibotn/eval.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/qibotn/eval.py b/src/qibotn/eval.py index bb5b104..8f74770 100644 --- a/src/qibotn/eval.py +++ b/src/qibotn/eval.py @@ -219,11 +219,20 @@ def reduce_result(result, comm, method="MPI", root=0): return comm.reduce(sendobj=result, op=MPI.SUM, root=root) elif method == "NCCL": stream_ptr = cp.cuda.get_current_stream().ptr + if result.dtype == cp.complex128: + count = result.size * 2 # complex128 has 2 float64 numbers + nccl_type = nccl.NCCL_FLOAT64 + elif result.dtype == cp.complex64: + count = result.size * 2 # complex64 has 2 float32 numbers + nccl_type = nccl.NCCL_FLOAT32 + else: + raise TypeError(f"Unsupported dtype for NCCL reduce: {result.dtype}") + comm.reduce( result.data.ptr, result.data.ptr, - result.size, - nccl.NCCL_FLOAT64, + count, + nccl_type, nccl.NCCL_SUM, root, stream_ptr,