Optimize BSSN EScalar GPU path baseline

This commit is contained in:
2026-05-02 18:19:15 +08:00
parent 52beb4d153
commit 59a216ad93
13 changed files with 1366 additions and 177 deletions

View File

@@ -79,6 +79,111 @@ int amss_analysis_map_every()
return every;
}
#if USE_CUDA_BSSN
// Reports whether the AMSS_ESCALAR_SPLIT_RP environment switch is on.
// The environment is consulted only on the first call; the outcome is kept
// in a function-local static and handed back unchanged afterwards.
// Returns 1 when the variable is set and atoi() of its value is non-zero,
// otherwise 0.
int amss_escalar_split_rp_enabled()
{
  static int cached_flag = -1; // -1 means "environment not read yet"
  if (cached_flag >= 0)
    return cached_flag;

  const char *raw = getenv("AMSS_ESCALAR_SPLIT_RP");
  if (raw && atoi(raw) != 0)
    cached_flag = 1;
  else
    cached_flag = 0;
  return cached_flag;
}
// Reports whether the AMSS_ESCALAR_SPLIT_RP_RECURSIVE environment switch is on.
// The environment is consulted only on the first call; the outcome is kept
// in a function-local static and handed back unchanged afterwards.
// Returns 1 when the variable is set and atoi() of its value is non-zero,
// otherwise 0.
int amss_escalar_split_rp_recursive_enabled()
{
  static int cached_flag = -1; // -1 means "environment not read yet"
  if (cached_flag >= 0)
    return cached_flag;

  const char *raw = getenv("AMSS_ESCALAR_SPLIT_RP_RECURSIVE");
  if (raw && atoi(raw) != 0)
    cached_flag = 1;
  else
    cached_flag = 0;
  return cached_flag;
}
// Builds a freshly allocated MyList<var> chain from a slice of `src`.
//   src  - head of the list to copy from (may be null)
//   skip - number of leading nodes to step over before copying starts
//   take - upper bound on the number of nodes copied; a negative value
//          copies everything up to the end of `src`
// Each new node is constructed from the corresponding source node's `data`.
// Returns the head of the new chain, or nullptr when no node was copied.
MyList<var> *clone_var_sublist(MyList<var> *src, int skip, int take)
{
  // Advance past the skipped prefix, stopping early if the list runs out.
  int skipped = 0;
  while (skipped < skip && src)
  {
    src = src->next;
    ++skipped;
  }

  MyList<var> *head = nullptr;
  MyList<var> **link = &head; // slot that receives the next new node
  for (int made = 0; src && (take < 0 || made < take); ++made, src = src->next)
  {
    MyList<var> *fresh = new MyList<var>(src->data);
    *link = fresh;
    link = &fresh->next;
  }
  return head;
}
// Releases a temporary variable list through MyList::clearList() and resets
// the caller's pointer to nullptr. A null `list` is left untouched.
void clear_tmp_var_list(MyList<var> *&list)
{
  if (!list)
    return;

  list->clearList();
  list = nullptr;
}
// Returns the number of nodes reachable from `vars` by following `next`
// (0 for a null list).
int var_list_count(MyList<var> *vars)
{
  int n = 0;
  for (MyList<var> *it = vars; it; it = it->next)
    ++n;
  return n;
}
// Fills `views[0 .. BSSN_CUDA_STATE_COUNT-1]` with the block's field pointers
// for the first BSSN_CUDA_STATE_COUNT variables of `vars`
// (views[i] = cg->fgfs[var->sgfn]).
//   cg    - block whose `fgfs` table is indexed
//   vars  - variable list; must hold at least BSSN_CUDA_STATE_COUNT entries
//   views - output array with room for BSSN_CUDA_STATE_COUNT pointers
// Returns true only when every slot was filled with a non-null pointer.
// On a false return `views` may be partially written.
bool bssn_prefix_views(Block *cg, MyList<var> *vars, double **views)
{
  if (cg == nullptr || vars == nullptr || views == nullptr)
    return false;

  MyList<var> *cursor = vars;
  int slot = 0;
  while (slot < BSSN_CUDA_STATE_COUNT)
  {
    // List ended before the whole prefix was collected.
    if (cursor == nullptr)
      return false;

    views[slot] = cg->fgfs[cursor->data->sgfn];
    if (views[slot] == nullptr)
      return false;

    cursor = cursor->next;
    ++slot;
  }
  return true;
}
// Walks every patch in `PatL` and, for each block whose `rank` equals
// `myrank`, collects the block's field pointers for the first
// BSSN_CUDA_STATE_COUNT variables of `vars` and passes them to
// bssn_cuda_download_resident_state_if_present(). Blocks for which the
// pointer set cannot be assembled are skipped.
//   PatL   - head of the patch list to visit (may be null)
//   vars   - variable list whose leading entries select the fields
//   myrank - rank of the calling process
void download_bssn_prefix_for_list(MyList<Patch> *PatL,
                                   MyList<var> *vars,
                                   int myrank)
{
  for (MyList<Patch> *patch_node = PatL; patch_node; patch_node = patch_node->next)
  {
    // A patch's blocks span blb .. ble inclusive; stop once ble was handled.
    for (MyList<Block> *block_node = patch_node->data->blb; block_node; block_node = block_node->next)
    {
      Block *blk = block_node->data;
      if (blk->rank == myrank)
      {
        double *views[BSSN_CUDA_STATE_COUNT];
        if (bssn_prefix_views(blk, vars, views))
          bssn_cuda_download_resident_state_if_present(blk, blk->shape, views);
      }
      if (block_node == patch_node->data->ble)
        break;
    }
  }
}
#endif
}
// Compile-time switch for per-timestep memory usage collection/printing.
@@ -7000,6 +7105,108 @@ void bssn_class::RestrictProlong(int lev, int YN, bool BB,
// a_stream.setf(ios::left);
#endif
#if USE_CUDA_BSSN && (ABEtype == 1) && (RPB == 0) && (MIXOUTB == 0)
// Split restrict/prolong path. Taken only when lev > 0, the
// AMSS_ESCALAR_SPLIT_RP_RECURSIVE switch is enabled and SL holds more than
// BSSN_CUDA_STATE_COUNT variables. Every variable list is divided into its
// first BSSN_CUDA_STATE_COUNT entries ("...b") and the remaining entries
// ("...s"); each transfer step is then issued once per sublist.
// (Presumably "b" is the BSSN state and "s" the extra scalar fields -- confirm.)
if (lev > 0 && amss_escalar_split_rp_recursive_enabled() && var_list_count(SL) > BSSN_CUDA_STATE_COUNT)
{
// Prefix sublists: first BSSN_CUDA_STATE_COUNT entries of each list.
MyList<var> *SLb = clone_var_sublist(SL, 0, BSSN_CUDA_STATE_COUNT);
MyList<var> *OLb = clone_var_sublist(OL, 0, BSSN_CUDA_STATE_COUNT);
MyList<var> *corLb = clone_var_sublist(corL, 0, BSSN_CUDA_STATE_COUNT);
MyList<var> *preb = clone_var_sublist(SynchList_pre, 0, BSSN_CUDA_STATE_COUNT);
// Remainder sublists: everything after the prefix (take == -1 copies to the end).
MyList<var> *SLs = clone_var_sublist(SL, BSSN_CUDA_STATE_COUNT, -1);
MyList<var> *OLs = clone_var_sublist(OL, BSSN_CUDA_STATE_COUNT, -1);
MyList<var> *corLs = clone_var_sublist(corL, BSSN_CUDA_STATE_COUNT, -1);
MyList<var> *pres = clone_var_sublist(SynchList_pre, BSSN_CUDA_STATE_COUNT, -1);
if (lev > trfls && YN == 0)
{
// Run prepare_inter_time_level on every patch of level lev-1, writing into
// the "pre" lists; the corL lists are passed as an extra input only when BB is set.
MyList<Patch> *Pp = GH->PatL[lev - 1];
while (Pp)
{
if (BB)
{
Parallel::prepare_inter_time_level(Pp->data, SLb, OLb, corLb, preb, 0);
Parallel::prepare_inter_time_level(Pp->data, SLs, OLs, corLs, pres, 0);
}
else
{
Parallel::prepare_inter_time_level(Pp->data, SLb, OLb, preb, 0);
Parallel::prepare_inter_time_level(Pp->data, SLs, OLs, pres, 0);
}
Pp = Pp->next;
}
// Restrict step. In non-legacy builds the prefix sublist goes through the
// cached variant with the per-level cache; the remainder sublist always uses
// the uncached call.
#if AMSS_LEGACY_ABE_TRANSFER
Parallel::Restrict(GH->PatL[lev - 1], GH->PatL[lev], SLb, preb, Symmetry);
#else
Parallel::Restrict_cached(GH->PatL[lev - 1], GH->PatL[lev], SLb, preb, Symmetry, sync_cache_restrict[lev]);
#endif
Parallel::Restrict(GH->PatL[lev - 1], GH->PatL[lev], SLs, pres, Symmetry);
// Optional Sync of level lev-1 after the restriction (compile-time switch).
#if (RP_SYNC_COARSE_AFTER_RESTRICT == 1)
#if AMSS_LEGACY_ABE_TRANSFER
Parallel::Sync(GH->PatL[lev - 1], preb, Symmetry);
#else
Parallel::Sync_cached(GH->PatL[lev - 1], preb, Symmetry, sync_cache_rp_coarse[lev]);
#endif
Parallel::Sync(GH->PatL[lev - 1], pres, Symmetry);
#endif
// OutBdLow2Hi step between level lev-1 ("pre" lists) and level lev (SL sublists).
#if AMSS_LEGACY_ABE_TRANSFER
Parallel::OutBdLow2Hi(GH->PatL[lev - 1], GH->PatL[lev], preb, SLb, Symmetry);
#else
Parallel::OutBdLow2Hi_cached(GH->PatL[lev - 1], GH->PatL[lev], preb, SLb, Symmetry, sync_cache_outbd[lev]);
#endif
Parallel::OutBdLow2Hi(GH->PatL[lev - 1], GH->PatL[lev], pres, SLs, Symmetry);
}
else
{
// No prepare_inter_time_level call in this branch: the same Restrict / Sync /
// OutBdLow2Hi sequence is issued with the SL sublists on both sides.
#if AMSS_LEGACY_ABE_TRANSFER
Parallel::Restrict(GH->PatL[lev - 1], GH->PatL[lev], SLb, SLb, Symmetry);
#else
Parallel::Restrict_cached(GH->PatL[lev - 1], GH->PatL[lev], SLb, SLb, Symmetry, sync_cache_restrict[lev]);
#endif
Parallel::Restrict(GH->PatL[lev - 1], GH->PatL[lev], SLs, SLs, Symmetry);
#if (RP_SYNC_COARSE_AFTER_RESTRICT == 1)
#if AMSS_LEGACY_ABE_TRANSFER
Parallel::Sync(GH->PatL[lev - 1], SLb, Symmetry);
#else
Parallel::Sync_cached(GH->PatL[lev - 1], SLb, Symmetry, sync_cache_rp_coarse[lev]);
#endif
Parallel::Sync(GH->PatL[lev - 1], SLs, Symmetry);
#endif
#if AMSS_LEGACY_ABE_TRANSFER
Parallel::OutBdLow2Hi(GH->PatL[lev - 1], GH->PatL[lev], SLb, SLb, Symmetry);
#else
Parallel::OutBdLow2Hi_cached(GH->PatL[lev - 1], GH->PatL[lev], SLb, SLb, Symmetry, sync_cache_outbd[lev]);
#endif
Parallel::OutBdLow2Hi(GH->PatL[lev - 1], GH->PatL[lev], SLs, SLs, Symmetry);
}
// Final Sync of level lev for both sublists, common to both branches.
#if AMSS_LEGACY_ABE_TRANSFER
Parallel::Sync(GH->PatL[lev], SLb, Symmetry);
#else
Parallel::Sync_cached(GH->PatL[lev], SLb, Symmetry, sync_cache_rp_fine[lev]);
#endif
Parallel::Sync(GH->PatL[lev], SLs, Symmetry);
// Free the temporary sublists created above before leaving the function.
clear_tmp_var_list(SLb);
clear_tmp_var_list(OLb);
clear_tmp_var_list(corLb);
clear_tmp_var_list(preb);
clear_tmp_var_list(SLs);
clear_tmp_var_list(OLs);
clear_tmp_var_list(corLs);
clear_tmp_var_list(pres);
STEP_TIMER_ADD(TB_RESTRICT_PROLONG, timer_restrict_prolong);
// The split path is complete; the generic code after this #if region is skipped.
return;
}
// Split path not taken but SL is still longer than BSSN_CUDA_STATE_COUNT:
// before falling through to the generic code, call
// download_bssn_prefix_for_list for SL on level lev and for SL, OL (and corL
// when BB is set) on level lev-1.
if (lev > 0 && var_list_count(SL) > BSSN_CUDA_STATE_COUNT)
{
download_bssn_prefix_for_list(GH->PatL[lev], SL, myrank);
download_bssn_prefix_for_list(GH->PatL[lev - 1], SL, myrank);
download_bssn_prefix_for_list(GH->PatL[lev - 1], OL, myrank);
if (BB)
download_bssn_prefix_for_list(GH->PatL[lev - 1], corL, myrank);
}
#endif
if (lev > 0)
{
MyList<Patch> *Pp, *Ppc;
@@ -7355,6 +7562,117 @@ void bssn_class::RestrictProlong(int lev, int YN, bool BB)
// OldStateList 0 -----------
//
// SynchList_cor old -----------
#if USE_CUDA_BSSN && (ABEtype == 1) && (RPB == 0) && (MIXOUTB == 0)
// Split restrict/prolong path. Taken only when lev > 0, the
// AMSS_ESCALAR_SPLIT_RP switch is enabled and StateList holds more than
// BSSN_CUDA_STATE_COUNT variables. StateList, OldStateList, SynchList_pre and
// SynchList_cor are each divided into their first BSSN_CUDA_STATE_COUNT
// entries ("...B") and the remaining entries ("...S"); every transfer step is
// issued once per sublist.
// NOTE(review): in non-legacy builds both sublists are handed to the *_cached
// helpers with the same per-level cache entry (sync_cache_restrict[lev],
// sync_cache_rp_coarse[lev], sync_cache_outbd[lev], sync_cache_rp_fine[lev]).
// The split path gated by amss_escalar_split_rp_recursive_enabled() uses the
// uncached calls for the remainder list instead -- confirm that these caches
// may be shared between two different variable lists.
if (lev > 0 && amss_escalar_split_rp_enabled() &&
var_list_count(StateList) > BSSN_CUDA_STATE_COUNT)
{
// Prefix sublists: first BSSN_CUDA_STATE_COUNT entries of each list.
MyList<var> *StateB = clone_var_sublist(StateList, 0, BSSN_CUDA_STATE_COUNT);
MyList<var> *OldB = clone_var_sublist(OldStateList, 0, BSSN_CUDA_STATE_COUNT);
MyList<var> *PreB = clone_var_sublist(SynchList_pre, 0, BSSN_CUDA_STATE_COUNT);
MyList<var> *CorB = clone_var_sublist(SynchList_cor, 0, BSSN_CUDA_STATE_COUNT);
// Remainder sublists: everything after the prefix (take == -1 copies to the end).
MyList<var> *StateS = clone_var_sublist(StateList, BSSN_CUDA_STATE_COUNT, -1);
MyList<var> *OldS = clone_var_sublist(OldStateList, BSSN_CUDA_STATE_COUNT, -1);
MyList<var> *PreS = clone_var_sublist(SynchList_pre, BSSN_CUDA_STATE_COUNT, -1);
MyList<var> *CorS = clone_var_sublist(SynchList_cor, BSSN_CUDA_STATE_COUNT, -1);
if (lev > trfls && YN == 0)
{
// Rank 0 logs the two time values involved, tagged "/=: ".
if (myrank == 0)
cout << "/=: " << GH->Lt[lev - 1] << "," << GH->Lt[lev] + dT_lev << endl;
// Run prepare_inter_time_level on every patch of level lev-1, writing into
// the Pre lists; the Cor lists are passed as an extra input only when BB is set.
MyList<Patch> *Pp = GH->PatL[lev - 1];
while (Pp)
{
if (BB)
{
Parallel::prepare_inter_time_level(Pp->data, StateB, OldB, CorB, PreB, 0);
Parallel::prepare_inter_time_level(Pp->data, StateS, OldS, CorS, PreS, 0);
}
else
{
Parallel::prepare_inter_time_level(Pp->data, StateB, OldB, PreB, 0);
Parallel::prepare_inter_time_level(Pp->data, StateS, OldS, PreS, 0);
}
Pp = Pp->next;
}
// Restrict step with the Cor lists and the Pre lists.
#if AMSS_LEGACY_ABE_TRANSFER
Parallel::Restrict(GH->PatL[lev - 1], GH->PatL[lev], CorB, PreB, Symmetry);
Parallel::Restrict(GH->PatL[lev - 1], GH->PatL[lev], CorS, PreS, Symmetry);
#else
Parallel::Restrict_cached(GH->PatL[lev - 1], GH->PatL[lev], CorB, PreB, Symmetry, sync_cache_restrict[lev]);
Parallel::Restrict_cached(GH->PatL[lev - 1], GH->PatL[lev], CorS, PreS, Symmetry, sync_cache_restrict[lev]);
#endif
// Sync of level lev-1: unconditional in legacy builds, otherwise only when
// RP_SYNC_COARSE_AFTER_RESTRICT == 1.
#if AMSS_LEGACY_ABE_TRANSFER
Parallel::Sync(GH->PatL[lev - 1], PreB, Symmetry);
Parallel::Sync(GH->PatL[lev - 1], PreS, Symmetry);
#else
#if (RP_SYNC_COARSE_AFTER_RESTRICT == 1)
Parallel::Sync_cached(GH->PatL[lev - 1], PreB, Symmetry, sync_cache_rp_coarse[lev]);
Parallel::Sync_cached(GH->PatL[lev - 1], PreS, Symmetry, sync_cache_rp_coarse[lev]);
#endif
#endif
// OutBdLow2Hi step between level lev-1 (Pre lists) and level lev (Cor lists).
#if AMSS_LEGACY_ABE_TRANSFER
Parallel::OutBdLow2Hi(GH->PatL[lev - 1], GH->PatL[lev], PreB, CorB, Symmetry);
Parallel::OutBdLow2Hi(GH->PatL[lev - 1], GH->PatL[lev], PreS, CorS, Symmetry);
#else
Parallel::OutBdLow2Hi_cached(GH->PatL[lev - 1], GH->PatL[lev], PreB, CorB, Symmetry, sync_cache_outbd[lev]);
Parallel::OutBdLow2Hi_cached(GH->PatL[lev - 1], GH->PatL[lev], PreS, CorS, Symmetry, sync_cache_outbd[lev]);
#endif
}
else
{
// Rank 0 logs the two time values involved, tagged "===: ".
if (myrank == 0)
cout << "===: " << GH->Lt[lev - 1] << "," << GH->Lt[lev] + dT_lev << endl;
// No prepare_inter_time_level call in this branch: the same Restrict / Sync /
// OutBdLow2Hi sequence is issued with the State lists in place of the Pre lists.
#if AMSS_LEGACY_ABE_TRANSFER
Parallel::Restrict(GH->PatL[lev - 1], GH->PatL[lev], CorB, StateB, Symmetry);
Parallel::Restrict(GH->PatL[lev - 1], GH->PatL[lev], CorS, StateS, Symmetry);
#else
Parallel::Restrict_cached(GH->PatL[lev - 1], GH->PatL[lev], CorB, StateB, Symmetry, sync_cache_restrict[lev]);
Parallel::Restrict_cached(GH->PatL[lev - 1], GH->PatL[lev], CorS, StateS, Symmetry, sync_cache_restrict[lev]);
#endif
#if AMSS_LEGACY_ABE_TRANSFER
Parallel::Sync(GH->PatL[lev - 1], StateB, Symmetry);
Parallel::Sync(GH->PatL[lev - 1], StateS, Symmetry);
#else
#if (RP_SYNC_COARSE_AFTER_RESTRICT == 1)
Parallel::Sync_cached(GH->PatL[lev - 1], StateB, Symmetry, sync_cache_rp_coarse[lev]);
Parallel::Sync_cached(GH->PatL[lev - 1], StateS, Symmetry, sync_cache_rp_coarse[lev]);
#endif
#endif
#if AMSS_LEGACY_ABE_TRANSFER
Parallel::OutBdLow2Hi(GH->PatL[lev - 1], GH->PatL[lev], StateB, CorB, Symmetry);
Parallel::OutBdLow2Hi(GH->PatL[lev - 1], GH->PatL[lev], StateS, CorS, Symmetry);
#else
Parallel::OutBdLow2Hi_cached(GH->PatL[lev - 1], GH->PatL[lev], StateB, CorB, Symmetry, sync_cache_outbd[lev]);
Parallel::OutBdLow2Hi_cached(GH->PatL[lev - 1], GH->PatL[lev], StateS, CorS, Symmetry, sync_cache_outbd[lev]);
#endif
}
// Final Sync of level lev on the Cor sublists, common to both branches.
#if AMSS_LEGACY_ABE_TRANSFER
Parallel::Sync(GH->PatL[lev], CorB, Symmetry);
Parallel::Sync(GH->PatL[lev], CorS, Symmetry);
#else
Parallel::Sync_cached(GH->PatL[lev], CorB, Symmetry, sync_cache_rp_fine[lev]);
Parallel::Sync_cached(GH->PatL[lev], CorS, Symmetry, sync_cache_rp_fine[lev]);
#endif
// Free the temporary sublists created above before leaving the function.
clear_tmp_var_list(StateB);
clear_tmp_var_list(OldB);
clear_tmp_var_list(PreB);
clear_tmp_var_list(CorB);
clear_tmp_var_list(StateS);
clear_tmp_var_list(OldS);
clear_tmp_var_list(PreS);
clear_tmp_var_list(CorS);
STEP_TIMER_ADD(TB_RESTRICT_PROLONG, timer_restrict_prolong);
// The split path is complete; the generic code after this #if region is skipped.
return;
}
#endif
if (lev > 0)
{
MyList<Patch> *Pp, *Ppc;