diff --git a/src/qibotn/eval.py b/src/qibotn/eval.py index bb5b104..8f74770 100644 --- a/src/qibotn/eval.py +++ b/src/qibotn/eval.py @@ -219,11 +219,20 @@ def reduce_result(result, comm, method="MPI", root=0): return comm.reduce(sendobj=result, op=MPI.SUM, root=root) elif method == "NCCL": stream_ptr = cp.cuda.get_current_stream().ptr + if result.dtype == cp.complex128: + count = result.size * 2 # complex128 has 2 float64 numbers + nccl_type = nccl.NCCL_FLOAT64 + elif result.dtype == cp.complex64: + count = result.size * 2 # complex64 has 2 float32 numbers + nccl_type = nccl.NCCL_FLOAT32 + else: + raise TypeError(f"Unsupported dtype for NCCL reduce: {result.dtype}") + comm.reduce( result.data.ptr, result.data.ptr, - result.size, - nccl.NCCL_FLOAT64, + count, + nccl_type, nccl.NCCL_SUM, root, stream_ptr,