Add EM GPU fast paths and defaults

This commit is contained in:
2026-05-07 12:18:56 +08:00
parent dd0e20d8c7
commit cb911dec06
6 changed files with 1720 additions and 183 deletions

View File

@@ -438,7 +438,7 @@ int count_bssn_cuda_state_list(MyList<var> *vars)
{
++count;
vars = vars->next;
if (count > BSSN_ESCALAR_CUDA_STATE_COUNT)
if (count > BSSN_EM_CUDA_STATE_COUNT)
return -1;
}
return count;
@@ -449,8 +449,7 @@ bool fill_bssn_cuda_views_count(Block *cg, MyList<var> *vars,
double **host_views)
{
if (!cg || !host_views ||
(state_count != BSSN_CUDA_STATE_COUNT &&
state_count != BSSN_ESCALAR_CUDA_STATE_COUNT))
state_count <= 0 || state_count > BSSN_EM_CUDA_STATE_COUNT)
return false;
int idx = 0;
while (vars && idx < state_count)
@@ -742,7 +741,7 @@ void bssn_cuda_download_level_state(MyList<Patch> *PatL, MyList<var> *vars, int
Block *cg = BP->data;
if (myrank == cg->rank && bssn_cuda_has_resident_state(cg))
{
double *state_out[BSSN_ESCALAR_CUDA_STATE_COUNT];
double *state_out[BSSN_EM_CUDA_STATE_COUNT];
if (!fill_bssn_cuda_views_count(cg, vars, state_count, state_out))
{
cout << "CUDA BSSN state list mismatch on resident state download" << endl;
@@ -750,7 +749,9 @@ void bssn_cuda_download_level_state(MyList<Patch> *PatL, MyList<var> *vars, int
}
const int rc = (state_count == BSSN_ESCALAR_CUDA_STATE_COUNT)
? bssn_escalar_cuda_download_resident_state(cg, cg->shape, state_out)
: bssn_cuda_download_resident_state(cg, cg->shape, state_out);
: ((state_count == BSSN_CUDA_STATE_COUNT)
? bssn_cuda_download_resident_state(cg, cg->shape, state_out)
: bssn_cuda_download_resident_state_count_if_present(cg, cg->shape, state_out, state_count));
if (rc)
{
cout << "CUDA resident state download failed" << endl;
@@ -779,7 +780,7 @@ void bssn_cuda_download_level_state_if_present(MyList<Patch> *PatL, MyList<var>
Block *cg = BP->data;
if (myrank == cg->rank && bssn_cuda_has_resident_state(cg))
{
double *state_out[BSSN_ESCALAR_CUDA_STATE_COUNT];
double *state_out[BSSN_EM_CUDA_STATE_COUNT];
if (!fill_bssn_cuda_views_count(cg, vars, state_count, state_out))
{
cout << "CUDA BSSN state list mismatch on resident state conditional download" << endl;