diff --git a/src/qibotn/eval.py b/src/qibotn/eval.py index 8f74770..0629c93 100644 --- a/src/qibotn/eval.py +++ b/src/qibotn/eval.py @@ -220,10 +220,10 @@ def reduce_result(result, comm, method="MPI", root=0): elif method == "NCCL": stream_ptr = cp.cuda.get_current_stream().ptr if result.dtype == cp.complex128: - count = result.size * 2 # complex128 has 2 float64 numbers + count = result.size * 2 # complex128 has 2 float64 numbers nccl_type = nccl.NCCL_FLOAT64 elif result.dtype == cp.complex64: - count = result.size * 2 # complex64 has 2 float32 numbers + count = result.size * 2 # complex64 has 2 float32 numbers nccl_type = nccl.NCCL_FLOAT32 else: raise TypeError(f"Unsupported dtype for NCCL reduce: {result.dtype}")