Trim GPU restrict sync overhead

This commit is contained in:
2026-04-12 19:45:34 +08:00
parent ce88c18265
commit d702aa06b9
3 changed files with 178 additions and 158 deletions

View File

@@ -207,6 +207,17 @@ bool parallel_gpu_unpack_segments(const double *data,
return true; return true;
} }
// Count the nodes of the singly linked variable list that starts at
// `var_list`, following `next` until a null link is reached.
// A null head pointer yields 0.
int parallel_var_list_count(MyList<var> *var_list)
{
  int n = 0;
  for (MyList<var> *node = var_list; node; node = node->next)
    ++n;
  return n;
}
void parallel_report_mpi_error(const char *context, int errcode, int req_no) void parallel_report_mpi_error(const char *context, int errcode, int req_no)
{ {
char errstr[MPI_MAX_ERROR_STRING]; char errstr[MPI_MAX_ERROR_STRING];
@@ -4646,7 +4657,8 @@ Parallel::SyncCache::SyncCache()
: valid(false), cpusize(0), combined_src(0), combined_dst(0), : valid(false), cpusize(0), combined_src(0), combined_dst(0),
send_lengths(0), recv_lengths(0), send_bufs(0), recv_bufs(0), send_lengths(0), recv_lengths(0), send_bufs(0), recv_bufs(0),
send_buf_caps(0), recv_buf_caps(0), reqs(0), stats(0), max_reqs(0), send_buf_caps(0), recv_buf_caps(0), reqs(0), stats(0), max_reqs(0),
lengths_valid(false), tc_req_node(0), tc_req_is_recv(0), tc_completed(0) lengths_valid(false), lengths_var_count(-1),
tc_req_node(0), tc_req_is_recv(0), tc_completed(0)
{ {
} }
// SyncCache invalidate: free grid segment lists but keep buffers // SyncCache invalidate: free grid segment lists but keep buffers
@@ -4665,6 +4677,7 @@ void Parallel::SyncCache::invalidate()
} }
valid = false; valid = false;
lengths_valid = false; lengths_valid = false;
lengths_var_count = -1;
} }
// SyncCache destroy: free everything // SyncCache destroy: free everything
void Parallel::SyncCache::destroy() void Parallel::SyncCache::destroy()
@@ -4695,6 +4708,8 @@ void Parallel::SyncCache::destroy()
reqs = 0; stats = 0; reqs = 0; stats = 0;
tc_req_node = 0; tc_req_is_recv = 0; tc_completed = 0; tc_req_node = 0; tc_req_is_recv = 0; tc_completed = 0;
cpusize = 0; max_reqs = 0; cpusize = 0; max_reqs = 0;
lengths_valid = false;
lengths_var_count = -1;
} }
// transfer_cached: reuse pre-allocated buffers from SyncCache // transfer_cached: reuse pre-allocated buffers from SyncCache
void Parallel::transfer_cached(MyList<Parallel::gridseg> **src, MyList<Parallel::gridseg> **dst, void Parallel::transfer_cached(MyList<Parallel::gridseg> **src, MyList<Parallel::gridseg> **dst,
@@ -4709,6 +4724,8 @@ void Parallel::transfer_cached(MyList<Parallel::gridseg> **src, MyList<Parallel:
int req_no = 0; int req_no = 0;
int pending_recv = 0; int pending_recv = 0;
const int mpi_tag = parallel_next_transfer_tag(); const int mpi_tag = parallel_next_transfer_tag();
const int current_var_count = parallel_var_list_count(VarList1);
const bool lengths_match = cache.lengths_valid && cache.lengths_var_count == current_var_count;
int node; int node;
int *req_node = cache.tc_req_node; int *req_node = cache.tc_req_node;
int *req_is_recv = cache.tc_req_is_recv; int *req_is_recv = cache.tc_req_is_recv;
@@ -4719,8 +4736,14 @@ void Parallel::transfer_cached(MyList<Parallel::gridseg> **src, MyList<Parallel:
{ {
if (node == myrank) continue; if (node == myrank) continue;
int rlength = data_packer(0, src[node], dst[node], node, UNPACK, VarList1, VarList2, Symmetry); int rlength;
cache.recv_lengths[node] = rlength; if (!lengths_match)
{
rlength = data_packer(0, src[node], dst[node], node, UNPACK, VarList1, VarList2, Symmetry);
cache.recv_lengths[node] = rlength;
}
else
rlength = cache.recv_lengths[node];
if (rlength > 0) if (rlength > 0)
{ {
if (rlength > cache.recv_buf_caps[node]) if (rlength > cache.recv_buf_caps[node])
@@ -4738,8 +4761,14 @@ void Parallel::transfer_cached(MyList<Parallel::gridseg> **src, MyList<Parallel:
} }
// Local transfer on this rank. // Local transfer on this rank.
int self_len = data_packer(0, src[myrank], dst[myrank], myrank, PACK, VarList1, VarList2, Symmetry); int self_len;
cache.recv_lengths[myrank] = self_len; if (!lengths_match)
{
self_len = data_packer(0, src[myrank], dst[myrank], myrank, PACK, VarList1, VarList2, Symmetry);
cache.recv_lengths[myrank] = self_len;
}
else
self_len = cache.recv_lengths[myrank];
if (self_len > 0) if (self_len > 0)
{ {
if (self_len > cache.recv_buf_caps[myrank]) if (self_len > cache.recv_buf_caps[myrank])
@@ -4756,8 +4785,14 @@ void Parallel::transfer_cached(MyList<Parallel::gridseg> **src, MyList<Parallel:
{ {
if (node == myrank) continue; if (node == myrank) continue;
int slength = data_packer(0, src[myrank], dst[myrank], node, PACK, VarList1, VarList2, Symmetry); int slength;
cache.send_lengths[node] = slength; if (!lengths_match)
{
slength = data_packer(0, src[myrank], dst[myrank], node, PACK, VarList1, VarList2, Symmetry);
cache.send_lengths[node] = slength;
}
else
slength = cache.send_lengths[node];
if (slength > 0) if (slength > 0)
{ {
if (slength > cache.send_buf_caps[node]) if (slength > cache.send_buf_caps[node])
@@ -4774,6 +4809,9 @@ void Parallel::transfer_cached(MyList<Parallel::gridseg> **src, MyList<Parallel:
} }
} }
cache.lengths_valid = true;
cache.lengths_var_count = current_var_count;
// Unpack as soon as receive completes to reduce pure wait time. // Unpack as soon as receive completes to reduce pure wait time.
while (pending_recv > 0) while (pending_recv > 0)
{ {
@@ -5000,6 +5038,8 @@ void Parallel::Sync_start(MyList<Patch> *PatL, MyList<var> *VarList, int Symmetr
int myrank; int myrank;
MPI_Comm_rank(MPI_COMM_WORLD, &myrank); MPI_Comm_rank(MPI_COMM_WORLD, &myrank);
int cpusize = cache.cpusize; int cpusize = cache.cpusize;
const int current_var_count = parallel_var_list_count(VarList);
const bool lengths_match = cache.lengths_valid && cache.lengths_var_count == current_var_count;
state.req_no = 0; state.req_no = 0;
state.active = true; state.active = true;
state.mpi_tag = parallel_next_transfer_tag(); state.mpi_tag = parallel_next_transfer_tag();
@@ -5017,7 +5057,7 @@ void Parallel::Sync_start(MyList<Patch> *PatL, MyList<var> *VarList, int Symmetr
if (node == myrank) if (node == myrank)
{ {
int length; int length;
if (!cache.lengths_valid) { if (!lengths_match) {
length = data_packer(0, src[myrank], dst[myrank], node, PACK, VarList, VarList, Symmetry); length = data_packer(0, src[myrank], dst[myrank], node, PACK, VarList, VarList, Symmetry);
cache.recv_lengths[node] = length; cache.recv_lengths[node] = length;
} else { } else {
@@ -5040,7 +5080,7 @@ void Parallel::Sync_start(MyList<Patch> *PatL, MyList<var> *VarList, int Symmetr
else else
{ {
int slength; int slength;
if (!cache.lengths_valid) { if (!lengths_match) {
slength = data_packer(0, src[myrank], dst[myrank], node, PACK, VarList, VarList, Symmetry); slength = data_packer(0, src[myrank], dst[myrank], node, PACK, VarList, VarList, Symmetry);
cache.send_lengths[node] = slength; cache.send_lengths[node] = slength;
} else { } else {
@@ -5063,7 +5103,7 @@ void Parallel::Sync_start(MyList<Patch> *PatL, MyList<var> *VarList, int Symmetr
MPI_Isend((void *)cache.send_bufs[node], slength, MPI_DOUBLE, node, state.mpi_tag, MPI_COMM_WORLD, cache.reqs + state.req_no++); MPI_Isend((void *)cache.send_bufs[node], slength, MPI_DOUBLE, node, state.mpi_tag, MPI_COMM_WORLD, cache.reqs + state.req_no++);
} }
int rlength; int rlength;
if (!cache.lengths_valid) { if (!lengths_match) {
rlength = data_packer(0, src[node], dst[node], node, UNPACK, VarList, VarList, Symmetry); rlength = data_packer(0, src[node], dst[node], node, UNPACK, VarList, VarList, Symmetry);
cache.recv_lengths[node] = rlength; cache.recv_lengths[node] = rlength;
} else { } else {
@@ -5085,6 +5125,7 @@ void Parallel::Sync_start(MyList<Patch> *PatL, MyList<var> *VarList, int Symmetr
} }
} }
cache.lengths_valid = true; cache.lengths_valid = true;
cache.lengths_var_count = current_var_count;
} }
// Sync_finish: progressive unpack as receives complete, then wait for sends // Sync_finish: progressive unpack as receives complete, then wait for sends
void Parallel::Sync_finish(SyncCache &cache, AsyncSyncState &state, void Parallel::Sync_finish(SyncCache &cache, AsyncSyncState &state,
@@ -6348,17 +6389,25 @@ void Parallel::OutBdLow2Himix_cached(MyList<Patch> *PatcL, MyList<Patch> *PatfL,
int req_no = 0; int req_no = 0;
int pending_recv = 0; int pending_recv = 0;
const int mpi_tag = parallel_next_transfer_tag(); const int mpi_tag = parallel_next_transfer_tag();
int *req_node = new int[cache.max_reqs]; const int current_var_count = parallel_var_list_count(VarList1);
int *req_is_recv = new int[cache.max_reqs]; const bool lengths_match = cache.lengths_valid && cache.lengths_var_count == current_var_count;
int *completed = new int[cache.max_reqs]; int *req_node = cache.tc_req_node;
int *req_is_recv = cache.tc_req_is_recv;
int *completed = cache.tc_completed;
// Post receives first so peers can progress rendezvous early. // Post receives first so peers can progress rendezvous early.
for (int node = 0; node < cpusize; node++) for (int node = 0; node < cpusize; node++)
{ {
if (node == myrank) continue; if (node == myrank) continue;
int rlength = data_packermix(0, cache.combined_src[node], cache.combined_dst[node], node, UNPACK, VarList1, VarList2, Symmetry); int rlength;
cache.recv_lengths[node] = rlength; if (!lengths_match)
{
rlength = data_packermix(0, cache.combined_src[node], cache.combined_dst[node], node, UNPACK, VarList1, VarList2, Symmetry);
cache.recv_lengths[node] = rlength;
}
else
rlength = cache.recv_lengths[node];
if (rlength > 0) if (rlength > 0)
{ {
if (rlength > cache.recv_buf_caps[node]) if (rlength > cache.recv_buf_caps[node])
@@ -6376,8 +6425,14 @@ void Parallel::OutBdLow2Himix_cached(MyList<Patch> *PatcL, MyList<Patch> *PatfL,
} }
// Local transfer on this rank. // Local transfer on this rank.
int self_len = data_packermix(0, cache.combined_src[myrank], cache.combined_dst[myrank], myrank, PACK, VarList1, VarList2, Symmetry); int self_len;
cache.recv_lengths[myrank] = self_len; if (!lengths_match)
{
self_len = data_packermix(0, cache.combined_src[myrank], cache.combined_dst[myrank], myrank, PACK, VarList1, VarList2, Symmetry);
cache.recv_lengths[myrank] = self_len;
}
else
self_len = cache.recv_lengths[myrank];
if (self_len > 0) if (self_len > 0)
{ {
if (self_len > cache.recv_buf_caps[myrank]) if (self_len > cache.recv_buf_caps[myrank])
@@ -6394,8 +6449,14 @@ void Parallel::OutBdLow2Himix_cached(MyList<Patch> *PatcL, MyList<Patch> *PatfL,
{ {
if (node == myrank) continue; if (node == myrank) continue;
int slength = data_packermix(0, cache.combined_src[myrank], cache.combined_dst[myrank], node, PACK, VarList1, VarList2, Symmetry); int slength;
cache.send_lengths[node] = slength; if (!lengths_match)
{
slength = data_packermix(0, cache.combined_src[myrank], cache.combined_dst[myrank], node, PACK, VarList1, VarList2, Symmetry);
cache.send_lengths[node] = slength;
}
else
slength = cache.send_lengths[node];
if (slength > 0) if (slength > 0)
{ {
if (slength > cache.send_buf_caps[node]) if (slength > cache.send_buf_caps[node])
@@ -6412,6 +6473,9 @@ void Parallel::OutBdLow2Himix_cached(MyList<Patch> *PatcL, MyList<Patch> *PatfL,
} }
} }
cache.lengths_valid = true;
cache.lengths_var_count = current_var_count;
// Unpack as soon as receive completes to reduce pure wait time. // Unpack as soon as receive completes to reduce pure wait time.
while (pending_recv > 0) while (pending_recv > 0)
{ {
@@ -6436,10 +6500,6 @@ void Parallel::OutBdLow2Himix_cached(MyList<Patch> *PatcL, MyList<Patch> *PatfL,
if (self_len > 0) if (self_len > 0)
data_packermix(cache.recv_bufs[myrank], cache.combined_src[myrank], cache.combined_dst[myrank], myrank, UNPACK, VarList1, VarList2, Symmetry); data_packermix(cache.recv_bufs[myrank], cache.combined_src[myrank], cache.combined_dst[myrank], myrank, UNPACK, VarList1, VarList2, Symmetry);
delete[] req_node;
delete[] req_is_recv;
delete[] completed;
} }
// collect all buffer grid segments or blocks for given patch // collect all buffer grid segments or blocks for given patch

View File

@@ -111,6 +111,7 @@ namespace Parallel
MPI_Status *stats; MPI_Status *stats;
int max_reqs; int max_reqs;
bool lengths_valid; bool lengths_valid;
int lengths_var_count;
int *tc_req_node; int *tc_req_node;
int *tc_req_is_recv; int *tc_req_is_recv;
int *tc_completed; int *tc_completed;

View File

@@ -2016,25 +2016,6 @@ int bssn_cuda_prolong3_pack(int wei,
if (!launch_kernel(grid, block, (const void *)prolong3_cell_kernel, args)) if (!launch_kernel(grid, block, (const void *)prolong3_cell_kernel, args))
return 1; return 1;
cudaError_t sync_err = cudaDeviceSynchronize();
if (sync_err != cudaSuccess)
{
std::fprintf(stderr,
"prolong3 debug: symmetry=%d extc=(%d,%d,%d) extf=(%d,%d,%d) "
"imino=%d imaxo=%d jmino=%d jmaxo=%d kmino=%d kmaxo=%d "
"ic_min=%d ic_max=%d jc_min=%d jc_max=%d kc_min=%d kc_max=%d "
"lbc=(%d,%d,%d) lbf=(%d,%d,%d)\n",
symmetry,
extc[0], extc[1], extc[2],
extf[0], extf[1], extf[2],
imino, imaxo, jmino, jmaxo, kmino, kmaxo,
ic_min, ic_max, jc_min, jc_max, kc_min, kc_max,
lbc[0], lbc[1], lbc[2],
lbf[0], lbf[1], lbf[2]);
report_cuda_error("cudaDeviceSynchronize prolong3", sync_err);
return 1;
}
int host_error_flag = 0; int host_error_flag = 0;
err = cudaMemcpy(&host_error_flag, cache.error_flag.ptr, sizeof(int), cudaMemcpyDeviceToHost); err = cudaMemcpy(&host_error_flag, cache.error_flag.ptr, sizeof(int), cudaMemcpyDeviceToHost);
if (err != cudaSuccess) if (err != cudaSuccess)
@@ -2241,28 +2222,6 @@ int bssn_cuda_restrict3_pack(int wei,
if (!launch_kernel(grid, block, (const void *)restrict3_cell_kernel, args)) if (!launch_kernel(grid, block, (const void *)restrict3_cell_kernel, args))
return 1; return 1;
cudaError_t sync_err = cudaDeviceSynchronize();
if (sync_err != cudaSuccess)
{
std::fprintf(stderr,
"restrict3 debug: symmetry=%d extc=(%d,%d,%d) extf=(%d,%d,%d) "
"imino=%d imaxo=%d jmino=%d jmaxo=%d kmino=%d kmaxo=%d "
"imini=%d imaxi=%d jmini=%d jmaxi=%d kmini=%d kmaxi=%d "
"lbc=(%d,%d,%d) lbf=(%d,%d,%d) "
"fi=[%d,%d] fj=[%d,%d] fk=[%d,%d] window=[%d:%d,%d:%d,%d:%d]\n",
symmetry,
extc[0], extc[1], extc[2],
extf[0], extf[1], extf[2],
imino, imaxo, jmino, jmaxo, kmino, kmaxo,
imini, imaxi, jmini, jmaxi, kmini, kmaxi,
lbc[0], lbc[1], lbc[2],
lbf[0], lbf[1], lbf[2],
fi_min, fi_max, fj_min, fj_max, fk_min, fk_max,
ii_lo, ii_hi, jj_lo, jj_hi, kk_lo, kk_hi);
report_cuda_error("cudaDeviceSynchronize restrict3", sync_err);
return 1;
}
int host_error_flag = 0; int host_error_flag = 0;
err = cudaMemcpy(&host_error_flag, cache.error_flag.ptr, sizeof(int), cudaMemcpyDeviceToHost); err = cudaMemcpy(&host_error_flag, cache.error_flag.ptr, sizeof(int), cudaMemcpyDeviceToHost);
if (err != cudaSuccess) if (err != cudaSuccess)