Fix GPU RK4 boundary and sync correctness

This commit is contained in:
2026-04-12 12:13:47 +08:00
parent b78874ef21
commit d9287ea530
4 changed files with 134 additions and 30 deletions

View File

@@ -2,14 +2,15 @@
#ifdef USE_GPU
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <vector>
#include "bssn_class.h"
#include "bssn_cuda_ops.h"
#include "bssn_gpu.h"
#include "bssn_macro.h"
#include "rungekutta4_rout.h"
void bssn_class::Step_MainPath_GPU(int lev, int YN)
{
@@ -56,11 +57,6 @@ void bssn_class::Step_MainPath_GPU(int lev, int YN)
const bool BB = fgt(PhysTime, StartTime, dT_lev / 2);
(void)BB;
#if (MAPBH == 0)
const bool need_host_stage_sync = (BH_num > 0 && lev == GH->levels - 1);
#else
const bool need_host_stage_sync = false;
#endif
double ndeps = (lev < GH->movls) ? numepsb : numepss;
double TRK4 = PhysTime;
int iter_count = 0;
@@ -83,6 +79,8 @@ void bssn_class::Step_MainPath_GPU(int lev, int YN)
patch->bbox[0], patch->bbox[1], patch->bbox[2],
patch->bbox[3], patch->bbox[4], patch->bbox[5],
cg->fgfs[varl0->data->sgfn],
cg->fgfs[phi0->sgfn],
cg->fgfs[Lap0->sgfn],
cg->fgfs[varlb->data->sgfn],
cg->fgfs[varls->data->sgfn],
cg->fgfs[varlr->data->sgfn],
@@ -124,6 +122,28 @@ void bssn_class::Step_MainPath_GPU(int lev, int YN)
}
};
auto stage_download_patch_list =
[&](MyList<var> *var_list) {
MyList<Patch> *patch_it = GH->PatL[lev];
while (patch_it)
{
MyList<Block> *block_it = patch_it->data->blb;
while (block_it)
{
Block *cg = block_it->data;
if (myrank == cg->rank)
stage_download_var_list(cg, var_list);
if (block_it == patch_it->data->ble)
break;
block_it = block_it->next;
}
if (ERROR)
break;
patch_it = patch_it->next;
}
};
auto ensure_stage_device_var_list =
[&](Block *cg, MyList<var> *var_list) {
const int n = cg->shape[0] * cg->shape[1] * cg->shape[2];
@@ -336,8 +356,6 @@ void bssn_class::Step_MainPath_GPU(int lev, int YN)
<< cg->bbox[2] << ":" << cg->bbox[5] << ")" << endl;
ERROR = 1;
}
if (!ERROR && !sync_cache_pre[lev].valid)
stage_download_var_list(cg, SynchList_pre);
}
if (BP == Pp->data->ble)
break;
@@ -346,8 +364,12 @@ void bssn_class::Step_MainPath_GPU(int lev, int YN)
Pp = Pp->next;
}
if (!ERROR && sync_cache_pre[lev].valid && !can_pack_sync_from_device(SynchList_pre, sync_cache_pre[lev]))
refresh_stage_host_before_sync(SynchList_pre, sync_cache_pre[lev]);
if (!ERROR)
{
stage_download_patch_list(SynchList_pre);
if (!ERROR)
bssn_gpu_clear_cached_device_buffers();
}
MPI_Request err_req_pre;
{
@@ -357,8 +379,8 @@ void bssn_class::Step_MainPath_GPU(int lev, int YN)
Parallel::AsyncSyncState async_pre;
Parallel::Sync_start(GH->PatL[lev], SynchList_pre, Symmetry, sync_cache_pre[lev], async_pre);
Parallel::Sync_finish(sync_cache_pre[lev], async_pre, SynchList_pre, Symmetry, need_host_stage_sync);
if (!ERROR && need_host_stage_sync)
Parallel::Sync_finish(sync_cache_pre[lev], async_pre, SynchList_pre, Symmetry, true);
if (!ERROR)
refresh_stage_device_after_sync(SynchList_pre, sync_cache_pre[lev]);
MPI_Wait(&err_req_pre, MPI_STATUS_IGNORE);
@@ -427,8 +449,6 @@ void bssn_class::Step_MainPath_GPU(int lev, int YN)
<< cg->bbox[2] << ":" << cg->bbox[5] << ")" << endl;
ERROR = 1;
}
if (!ERROR && (!sync_cache_cor[lev].valid || iter_count == 3))
stage_download_var_list(cg, SynchList_cor);
}
if (BP == Pp->data->ble)
@@ -438,9 +458,12 @@ void bssn_class::Step_MainPath_GPU(int lev, int YN)
Pp = Pp->next;
}
if (!ERROR && sync_cache_cor[lev].valid && iter_count < 3 &&
!can_pack_sync_from_device(SynchList_cor, sync_cache_cor[lev]))
refresh_stage_host_before_sync(SynchList_cor, sync_cache_cor[lev]);
if (!ERROR)
{
stage_download_patch_list(SynchList_cor);
if (!ERROR)
bssn_gpu_clear_cached_device_buffers();
}
MPI_Request err_req_cor;
{
@@ -450,9 +473,8 @@ void bssn_class::Step_MainPath_GPU(int lev, int YN)
Parallel::AsyncSyncState async_cor;
Parallel::Sync_start(GH->PatL[lev], SynchList_cor, Symmetry, sync_cache_cor[lev], async_cor);
const bool unpack_cor_to_host = (iter_count == 3) || need_host_stage_sync;
Parallel::Sync_finish(sync_cache_cor[lev], async_cor, SynchList_cor, Symmetry, unpack_cor_to_host);
if (!ERROR && iter_count < 3 && unpack_cor_to_host)
Parallel::Sync_finish(sync_cache_cor[lev], async_cor, SynchList_cor, Symmetry, true);
if (!ERROR && iter_count < 3)
refresh_stage_device_after_sync(SynchList_cor, sync_cache_cor[lev]);
MPI_Wait(&err_req_cor, MPI_STATUS_IGNORE);