Add EM GPU fast paths and defaults

This commit is contained in:
2026-05-07 12:18:56 +08:00
parent dd0e20d8c7
commit cb911dec06
6 changed files with 1720 additions and 183 deletions

View File

@@ -18,7 +18,7 @@
#endif
#if USE_CUDA_BSSN
#include "bssn_rhs_cuda.h"
#define AMSS_BSSN_CUDA_MAX_STATE_COUNT BSSN_ESCALAR_CUDA_STATE_COUNT
#define AMSS_BSSN_CUDA_MAX_STATE_COUNT BSSN_EM_CUDA_STATE_COUNT
#endif
#if USE_CUDA_Z4C
#include "z4c_rhs_cuda.h"
@@ -181,8 +181,7 @@ bool cuda_build_bssn_host_views(Block *block,
double **views)
{
if (!block || !vars || !views ||
(state_count != BSSN_CUDA_STATE_COUNT &&
state_count != BSSN_ESCALAR_CUDA_STATE_COUNT))
state_count <= 0 || state_count > AMSS_BSSN_CUDA_MAX_STATE_COUNT)
return false;
MyList<var> *v = vars;
for (int i = 0; i < state_count; ++i)
@@ -200,8 +199,7 @@ bool cuda_build_bssn_soa(MyList<var> *vars,
double *soa_flat)
{
if (!vars || !soa_flat ||
(state_count != BSSN_CUDA_STATE_COUNT &&
state_count != BSSN_ESCALAR_CUDA_STATE_COUNT))
state_count <= 0 || state_count > AMSS_BSSN_CUDA_MAX_STATE_COUNT)
return false;
MyList<var> *v = vars;
for (int i = 0; i < state_count; ++i)
@@ -322,7 +320,7 @@ bool cuda_state_count_direct_supported(int state_count)
#if USE_CUDA_Z4C && (ABEtype == 2)
return state_count == Z4C_CUDA_STATE_COUNT;
#elif USE_CUDA_BSSN
return state_count > 0 && state_count <= BSSN_ESCALAR_CUDA_STATE_COUNT;
return state_count > 0 && state_count <= AMSS_BSSN_CUDA_MAX_STATE_COUNT;
#else
(void)state_count;
return false;
@@ -550,7 +548,8 @@ bool cuda_uncached_device_buffers_enabled(int state_count)
}
if (!enabled)
return false;
if (state_count != BSSN_ESCALAR_CUDA_STATE_COUNT)
if (state_count != BSSN_ESCALAR_CUDA_STATE_COUNT &&
state_count != BSSN_EM_CUDA_STATE_COUNT)
return false;
return cuda_aware_mpi_enabled();
#else
@@ -6136,6 +6135,7 @@ void Parallel::transfer_cached(MyList<Parallel::gridseg> **src, MyList<Parallel:
MyList<var> *VarList1, MyList<var> *VarList2,
int Symmetry, SyncCache &cache)
{
const double t_transfer = sync_profile_enabled() ? MPI_Wtime() : 0.0;
int myrank;
MPI_Comm_size(MPI_COMM_WORLD, &cache.cpusize);
MPI_Comm_rank(MPI_COMM_WORLD, &myrank);
@@ -6324,6 +6324,13 @@ void Parallel::transfer_cached(MyList<Parallel::gridseg> **src, MyList<Parallel:
else
data_packer(cache.recv_bufs[myrank], src[myrank], dst[myrank], myrank, UNPACK, VarList1, VarList2, Symmetry);
}
if (sync_profile_enabled())
{
SyncProfileStats &stats = sync_profile_stats();
stats.finish_calls++;
stats.finish_sec += MPI_Wtime() - t_transfer;
sync_profile_maybe_log();
}
}
void Parallel::Sync_ensure_cache(MyList<Patch> *PatL, int Symmetry, SyncCache &cache)
{