Fix bug in NCCL reduce: the element count passed to NCCL must be 2x the array size for the complex128 dtype, because each complex value is sent as two float64 numbers.
This commit is contained in:
@@ -219,11 +219,20 @@ def reduce_result(result, comm, method="MPI", root=0):
|
|||||||
return comm.reduce(sendobj=result, op=MPI.SUM, root=root)
|
return comm.reduce(sendobj=result, op=MPI.SUM, root=root)
|
||||||
elif method == "NCCL":
|
elif method == "NCCL":
|
||||||
stream_ptr = cp.cuda.get_current_stream().ptr
|
stream_ptr = cp.cuda.get_current_stream().ptr
|
||||||
|
if result.dtype == cp.complex128:
|
||||||
|
count = result.size * 2 # complex128 has 2 float64 numbers
|
||||||
|
nccl_type = nccl.NCCL_FLOAT64
|
||||||
|
elif result.dtype == cp.complex64:
|
||||||
|
count = result.size * 2 # complex64 has 2 float32 numbers
|
||||||
|
nccl_type = nccl.NCCL_FLOAT32
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Unsupported dtype for NCCL reduce: {result.dtype}")
|
||||||
|
|
||||||
comm.reduce(
|
comm.reduce(
|
||||||
result.data.ptr,
|
result.data.ptr,
|
||||||
result.data.ptr,
|
result.data.ptr,
|
||||||
result.size,
|
count,
|
||||||
nccl.NCCL_FLOAT64,
|
nccl_type,
|
||||||
nccl.NCCL_SUM,
|
nccl.NCCL_SUM,
|
||||||
root,
|
root,
|
||||||
stream_ptr,
|
stream_ptr,
|
||||||
|
|||||||
Reference in New Issue
Block a user