Fix bug on NCCL reduce. Need to change size to 2x for complex128 dtype

This commit is contained in:
tankya2
2025-08-26 17:10:39 +08:00
parent cf0a539d3d
commit 39dad0fd88

View File

@@ -219,11 +219,20 @@ def reduce_result(result, comm, method="MPI", root=0):
return comm.reduce(sendobj=result, op=MPI.SUM, root=root)
elif method == "NCCL":
stream_ptr = cp.cuda.get_current_stream().ptr
if result.dtype == cp.complex128:
count = result.size * 2 # complex128 has 2 float64 numbers
nccl_type = nccl.NCCL_FLOAT64
elif result.dtype == cp.complex64:
count = result.size * 2 # complex64 has 2 float32 numbers
nccl_type = nccl.NCCL_FLOAT32
else:
raise TypeError(f"Unsupported dtype for NCCL reduce: {result.dtype}")
comm.reduce(
result.data.ptr,
result.data.ptr,
result.size,
nccl.NCCL_FLOAT64,
count,
nccl_type,
nccl.NCCL_SUM,
root,
stream_ptr,