Trim GPU main-path transfer overhead

This commit is contained in:
2026-04-08 20:16:25 +08:00
parent 01ac1f9250
commit a0af9b8804
2 changed files with 74 additions and 48 deletions

View File

@@ -134,6 +134,7 @@ struct GpuRhsCache
const double *last_x = nullptr;
const double *last_y = nullptr;
const double *last_z = nullptr;
bool meta_uploaded = false;
};
GpuRhsCache &gpu_rhs_cache()
@@ -231,9 +232,31 @@ bool register_gpu_rhs_cleanup()
return true;
}
void ensure_gpu_rhs_invariant_symbols()
{
static bool initialized = false;
if (initialized)
return;
double F1o3h = 1.0 / 3.0;
double F2o3h = 2.0 / 3.0;
double F1o6h = 1.0 / 6.0;
double PIh = M_PI;
int step = GRID_DIM * BLOCK_DIM;
cudaMemcpyToSymbol(F1o3, &F1o3h, sizeof(double));
cudaMemcpyToSymbol(F2o3, &F2o3h, sizeof(double));
cudaMemcpyToSymbol(F1o6, &F1o6h, sizeof(double));
cudaMemcpyToSymbol(PI, &PIh, sizeof(double));
cudaMemcpyToSymbol(STEP_SIZE, &step, sizeof(int));
initialized = true;
}
bool prepare_gpu_rhs_cache(GpuRhsCache &cache, int device, int *ex)
{
register_gpu_rhs_cleanup();
ensure_gpu_rhs_invariant_symbols();
const bool shape_changed =
!cache.allocated ||
@@ -261,6 +284,7 @@ bool prepare_gpu_rhs_cache(GpuRhsCache &cache, int device, int *ex)
cache.last_x = nullptr;
cache.last_y = nullptr;
cache.last_z = nullptr;
cache.meta_uploaded = false;
Meta *meta = &cache.meta;
const int matrix_size = cache.matrix_size;
@@ -446,7 +470,24 @@ bool prepare_gpu_rhs_cache(GpuRhsCache &cache, int device, int *ex)
return false;
}
cudaMemcpyToSymbol(metac, meta, sizeof(Meta));
int _1d_size[4];
int _2d_size[4];
int _3d_size[4];
for (int i = 0; i < 4; ++i)
{
_1d_size[i] = ex[0] + i;
_2d_size[i] = _1d_size[i] * (ex[1] + i);
_3d_size[i] = _2d_size[i] * (ex[2] + i);
}
cudaMemcpyToSymbol(ex_c, ex, 3 * sizeof(int));
cudaMemcpyToSymbol(_1D_SIZE, _1d_size, 4 * sizeof(int));
cudaMemcpyToSymbol(_2D_SIZE, _2d_size, 4 * sizeof(int));
cudaMemcpyToSymbol(_3D_SIZE, _3d_size, 4 * sizeof(int));
cache.allocated = true;
cache.meta_uploaded = true;
return true;
}
@@ -2989,46 +3030,21 @@ int gpu_rhs(int calledby, int mpi_rank, int *ex, double &T,double *X, double *Y,
#endif//if (GAUGE == 6 || GAUGE == 7)
//3.1-----for compute_rhs_bssn---------
//cout<<"Size of Meta:"<<sizeof(Meta)<<endl;
cudaMemcpyToSymbol(metac,meta, sizeof(Meta));
cudaMemcpyToSymbol(ex_c,ex, 3*sizeof(int));
cudaMemcpyToSymbol(T_c,&T, sizeof(double));
cudaMemcpyToSymbol(Symmetry_c,&Symmetry, sizeof(int));
cudaMemcpyToSymbol(Lev_c,&Lev, sizeof(int));
cudaMemcpyToSymbol(co_c,&co, sizeof(int));
cudaMemcpyToSymbol(eps_c,&eps, sizeof(double));
double F1o3h = 1.0; F1o3h /= 3.0;
double F2o3h = 2.0; F2o3h /= 3.0;
double F1o6h = 1.0; F1o6h /= 6.0;
double PIh = M_PI;
int step = GRID_DIM * BLOCK_DIM;
double dXh = X[1] - X[0];
double dYh = Y[1] - Y[0];
double dZh = Z[1] - Z[0];
cudaMemcpyToSymbol(F1o3,&F1o3h, sizeof(double));
cudaMemcpyToSymbol(F2o3,&F2o3h, sizeof(double));
cudaMemcpyToSymbol(F1o6,&F1o6h, sizeof(double));
cudaMemcpyToSymbol(PI,&PIh, sizeof(double));
cudaMemcpyToSymbol(STEP_SIZE,&step, sizeof(int));
cudaMemcpyToSymbol(dX,&dXh, sizeof(double));
cudaMemcpyToSymbol(dY,&dYh, sizeof(double));
cudaMemcpyToSymbol(dZ,&dZh, sizeof(double));
int _1d_size[4];
int _2d_size[4];
int _3d_size[4];
for(int i = 0;i<4;++i){
_1d_size[i] = ex[0] + i;
_2d_size[i] = _1d_size[i] * (ex[1]+i);
_3d_size[i] = _2d_size[i] * (ex[2]+i);
//cout<<_1d_size[i]<<' '<<_2d_size[i]<<' '<<_3d_size[i]<<endl;
}
cudaMemcpyToSymbol(_1D_SIZE,_1d_size, 4*sizeof(int));
cudaMemcpyToSymbol(_2D_SIZE,_2d_size, 4*sizeof(int));
cudaMemcpyToSymbol(_3D_SIZE,_3d_size, 4*sizeof(int));
//3.1-----for compute_rhs_bssn---------
//cout<<"Size of Meta:"<<sizeof(Meta)<<endl;
cudaMemcpyToSymbol(T_c,&T, sizeof(double));
cudaMemcpyToSymbol(Symmetry_c,&Symmetry, sizeof(int));
cudaMemcpyToSymbol(Lev_c,&Lev, sizeof(int));
cudaMemcpyToSymbol(co_c,&co, sizeof(int));
cudaMemcpyToSymbol(eps_c,&eps, sizeof(double));
double dXh = X[1] - X[0];
double dYh = Y[1] - Y[0];
double dZh = Z[1] - Z[0];
cudaMemcpyToSymbol(dX,&dXh, sizeof(double));
cudaMemcpyToSymbol(dY,&dYh, sizeof(double));
cudaMemcpyToSymbol(dZ,&dZh, sizeof(double));
//3.2--------for fderivs------------