Add batched CUDA patch interpolation path

This commit is contained in:
2026-04-09 14:56:01 +08:00
parent ad999e4c5a
commit c47349b7a9
3 changed files with 601 additions and 105 deletions

View File

@@ -10,21 +10,41 @@
#include <vector>
using namespace std;
#include "misc.h"
#include "MPatch.h"
#include "Parallel.h"
#include "fmisc.h"
#ifdef INTERP_LB_PROFILE
#include "interp_lb_profile.h"
#endif
namespace
{
struct InterpBlockView
{
Block *bp;
double llb[dim];
double uub[dim];
#include "misc.h"
#include "MPatch.h"
#include "Parallel.h"
#include "fmisc.h"
#include "bssn_cuda_ops.h"
#ifdef INTERP_LB_PROFILE
#include "interp_lb_profile.h"
#endif
#if defined(__GNUC__) || defined(__clang__)
extern int bssn_cuda_interp_points_batch(const int *ex,
const double *X, const double *Y, const double *Z,
const double *const *fields,
const double *soa_flat,
int num_var,
const double *px, const double *py, const double *pz,
int num_points,
int ordn,
int symmetry,
double *out) __attribute__((weak));
#endif
namespace
{
struct InterpVarDesc
{
int sgfn;
double soa[dim];
};
struct InterpBlockView
{
Block *bp;
double llb[dim];
double uub[dim];
};
struct BlockBinIndex
@@ -154,10 +174,10 @@ void build_block_bin_index(Patch *patch, const double *DH, BlockBinIndex &index)
index.valid = true;
}
int find_block_index_for_point(const BlockBinIndex &index, const double *pox, const double *DH)
{
if (!index.valid)
return -1;
int find_block_index_for_point(const BlockBinIndex &index, const double *pox, const double *DH)
{
if (!index.valid)
return -1;
const int bx = coord_to_bin(pox[0], index.lo[0], index.inv[0], index.bins[0]);
const int by = coord_to_bin(pox[1], index.lo[1], index.inv[1], index.bins[1]);
@@ -175,10 +195,151 @@ int find_block_index_for_point(const BlockBinIndex &index, const double *pox, co
for (size_t bi = 0; bi < index.views.size(); bi++)
if (point_in_block_view(index.views[bi], pox, DH))
return int(bi);
return -1;
}
} // namespace
return -1;
}
void collect_interp_vars(MyList<var> *VarList, vector<InterpVarDesc> &vars)
{
vars.clear();
MyList<var> *varl = VarList;
while (varl)
{
InterpVarDesc desc;
desc.sgfn = varl->data->sgfn;
for (int d = 0; d < dim; ++d)
desc.soa[d] = varl->data->SoA[d];
vars.push_back(desc);
varl = varl->next;
}
}
bool should_try_cuda_interp(int ordn, int num_points, int num_var)
{
#if defined(__GNUC__) || defined(__clang__)
if (!bssn_cuda_interp_points_batch)
return false;
#else
return false;
#endif
if (ordn != 6)
return false;
if (num_points < 32)
return false;
return num_points * num_var >= 256;
}
bool run_cuda_interp_for_block(Block *BP,
const vector<InterpVarDesc> &vars,
const vector<int> &point_ids,
double **XX,
double *Shellf,
int num_var,
int ordn,
int Symmetry)
{
if (!should_try_cuda_interp(ordn, static_cast<int>(point_ids.size()), num_var))
return false;
vector<const double *> field_ptrs(num_var);
vector<double> soa_flat(3 * num_var);
for (int v = 0; v < num_var; ++v)
{
field_ptrs[v] = BP->fgfs[vars[v].sgfn];
for (int d = 0; d < dim; ++d)
soa_flat[3 * v + d] = vars[v].soa[d];
}
const int npts = static_cast<int>(point_ids.size());
vector<double> px(npts), py(npts), pz(npts);
for (int p = 0; p < npts; ++p)
{
const int j = point_ids[p];
px[p] = XX[0][j];
py[p] = XX[1][j];
pz[p] = XX[2][j];
}
vector<double> out(static_cast<size_t>(npts) * static_cast<size_t>(num_var));
if (bssn_cuda_interp_points_batch(BP->shape,
BP->X[0], BP->X[1], BP->X[2],
field_ptrs.data(),
soa_flat.data(),
num_var,
px.data(), py.data(), pz.data(),
npts,
ordn,
Symmetry,
out.data()) != 0)
{
return false;
}
for (int p = 0; p < npts; ++p)
{
const int j = point_ids[p];
memcpy(Shellf + j * num_var, out.data() + p * num_var, sizeof(double) * num_var);
}
return true;
}
void run_cpu_interp_for_block(Block *BP,
const vector<InterpVarDesc> &vars,
const vector<int> &point_ids,
double **XX,
double *Shellf,
int num_var,
int ordn,
int Symmetry)
{
for (size_t p = 0; p < point_ids.size(); ++p)
{
const int j = point_ids[p];
double x = XX[0][j];
double y = XX[1][j];
double z = XX[2][j];
int ordn_local = ordn;
int symmetry_local = Symmetry;
for (int v = 0; v < num_var; ++v)
{
f_global_interp(BP->shape, BP->X[0], BP->X[1], BP->X[2],
BP->fgfs[vars[v].sgfn], Shellf[j * num_var + v],
x, y, z, ordn_local, const_cast<double *>(vars[v].soa), symmetry_local);
}
}
}
void interpolate_owned_points(MyList<var> *VarList,
int NN, double **XX,
double *Shellf, int Symmetry,
int myrank, int ordn,
const BlockBinIndex &block_index,
const int *owner_rank,
const int *owner_block)
{
vector<InterpVarDesc> vars;
collect_interp_vars(VarList, vars);
const int num_var = static_cast<int>(vars.size());
vector<vector<int>> block_points(block_index.views.size());
for (int j = 0; j < NN; ++j)
{
if (owner_rank[j] == myrank && owner_block[j] >= 0)
block_points[owner_block[j]].push_back(j);
}
for (size_t bi = 0; bi < block_points.size(); ++bi)
{
if (block_points[bi].empty())
continue;
Block *BP = block_index.views[bi].bp;
bool done = run_cuda_interp_for_block(BP, vars, block_points[bi], XX, Shellf, num_var, ordn, Symmetry);
if (!done)
run_cpu_interp_for_block(BP, vars, block_points[bi], XX, Shellf, num_var, ordn, Symmetry);
}
}
} // namespace
Patch::Patch(int DIM, int *shapei, double *bboxi, int levi, bool buflog, int Symmetry) : lev(levi)
{
@@ -523,12 +684,17 @@ void Patch::Interp_Points(MyList<var> *VarList,
memset(Shellf, 0, sizeof(double) * NN * num_var);
// owner_rank[j] records which MPI rank owns point j
// All ranks traverse the same block list so they all agree on ownership
int *owner_rank;
owner_rank = new int[NN];
for (int j = 0; j < NN; j++)
owner_rank[j] = -1;
// owner_rank[j] records which MPI rank owns point j
// All ranks traverse the same block list so they all agree on ownership
int *owner_rank;
owner_rank = new int[NN];
int *owner_block;
owner_block = new int[NN];
for (int j = 0; j < NN; j++)
{
owner_rank[j] = -1;
owner_block[j] = -1;
}
double DH[dim];
for (int i = 0; i < dim; i++)
@@ -558,25 +724,15 @@ void Patch::Interp_Points(MyList<var> *VarList,
}
const int block_i = find_block_index_for_point(block_index, pox, DH);
if (block_i >= 0)
{
Block *BP = block_index.views[block_i].bp;
owner_rank[j] = BP->rank;
if (myrank == BP->rank)
{
//---> interpolation
varl = VarList;
int k = 0;
while (varl) // run along variables
{
f_global_interp(BP->shape, BP->X[0], BP->X[1], BP->X[2], BP->fgfs[varl->data->sgfn], Shellf[j * num_var + k],
pox[0], pox[1], pox[2], ordn, varl->data->SoA, Symmetry);
varl = varl->next;
k++;
}
}
}
}
if (block_i >= 0)
{
Block *BP = block_index.views[block_i].bp;
owner_rank[j] = BP->rank;
owner_block[j] = block_i;
}
}
interpolate_owned_points(VarList, NN, XX, Shellf, Symmetry, myrank, ordn, block_index, owner_rank, owner_block);
// Replace MPI_Allreduce with per-owner MPI_Bcast:
// Group consecutive points by owner rank and broadcast each group.
@@ -631,9 +787,10 @@ void Patch::Interp_Points(MyList<var> *VarList,
MPI_Bcast(Shellf + jstart * num_var, count, MPI_DOUBLE, cur_owner, MPI_COMM_WORLD);
}
}
delete[] owner_rank;
}
delete[] owner_rank;
delete[] owner_block;
}
void Patch::Interp_Points(MyList<var> *VarList,
int NN, double **XX,
double *Shellf, int Symmetry,
@@ -661,11 +818,16 @@ void Patch::Interp_Points(MyList<var> *VarList,
memset(Shellf, 0, sizeof(double) * NN * num_var);
// owner_rank[j] records which MPI rank owns point j
int *owner_rank;
owner_rank = new int[NN];
for (int j = 0; j < NN; j++)
owner_rank[j] = -1;
// owner_rank[j] records which MPI rank owns point j
int *owner_rank;
owner_rank = new int[NN];
int *owner_block;
owner_block = new int[NN];
for (int j = 0; j < NN; j++)
{
owner_rank[j] = -1;
owner_block[j] = -1;
}
double DH[dim];
for (int i = 0; i < dim; i++)
@@ -696,24 +858,15 @@ void Patch::Interp_Points(MyList<var> *VarList,
}
const int block_i = find_block_index_for_point(block_index, pox, DH);
if (block_i >= 0)
{
Block *BP = block_index.views[block_i].bp;
owner_rank[j] = BP->rank;
if (myrank == BP->rank)
{
varl = VarList;
int k = 0;
while (varl)
{
f_global_interp(BP->shape, BP->X[0], BP->X[1], BP->X[2], BP->fgfs[varl->data->sgfn], Shellf[j * num_var + k],
pox[0], pox[1], pox[2], ordn, varl->data->SoA, Symmetry);
varl = varl->next;
k++;
}
}
}
}
if (block_i >= 0)
{
Block *BP = block_index.views[block_i].bp;
owner_rank[j] = BP->rank;
owner_block[j] = block_i;
}
}
interpolate_owned_points(VarList, NN, XX, Shellf, Symmetry, myrank, ordn, block_index, owner_rank, owner_block);
#ifdef INTERP_LB_PROFILE
double t_interp_end = MPI_Wtime();
@@ -873,9 +1026,10 @@ void Patch::Interp_Points(MyList<var> *VarList,
delete[] send_offset;
delete[] recv_offset;
delete[] send_count;
delete[] recv_count;
delete[] consumer_rank;
delete[] owner_rank;
delete[] recv_count;
delete[] consumer_rank;
delete[] owner_rank;
delete[] owner_block;
#ifdef INTERP_LB_PROFILE
{
@@ -923,11 +1077,16 @@ void Patch::Interp_Points(MyList<var> *VarList,
memset(Shellf, 0, sizeof(double) * NN * num_var);
// owner_rank[j] stores the global rank that owns point j
int *owner_rank;
owner_rank = new int[NN];
for (int j = 0; j < NN; j++)
owner_rank[j] = -1;
// owner_rank[j] stores the global rank that owns point j
int *owner_rank;
owner_rank = new int[NN];
int *owner_block;
owner_block = new int[NN];
for (int j = 0; j < NN; j++)
{
owner_rank[j] = -1;
owner_block[j] = -1;
}
// Build global-to-local rank translation for Comm_here
MPI_Group world_group, local_group;
@@ -962,25 +1121,15 @@ void Patch::Interp_Points(MyList<var> *VarList,
}
const int block_i = find_block_index_for_point(block_index, pox, DH);
if (block_i >= 0)
{
Block *BP = block_index.views[block_i].bp;
owner_rank[j] = BP->rank;
if (myrank == BP->rank)
{
//---> interpolation
varl = VarList;
int k = 0;
while (varl) // run along variables
{
f_global_interp(BP->shape, BP->X[0], BP->X[1], BP->X[2], BP->fgfs[varl->data->sgfn], Shellf[j * num_var + k],
pox[0], pox[1], pox[2], ordn, varl->data->SoA, Symmetry);
varl = varl->next;
k++;
}
}
}
}
if (block_i >= 0)
{
Block *BP = block_index.views[block_i].bp;
owner_rank[j] = BP->rank;
owner_block[j] = block_i;
}
}
interpolate_owned_points(VarList, NN, XX, Shellf, Symmetry, myrank, ordn, block_index, owner_rank, owner_block);
// Collect unique global owner ranks and translate to local ranks in Comm_here
// Then broadcast each owner's points via MPI_Bcast on Comm_here
@@ -1008,10 +1157,11 @@ void Patch::Interp_Points(MyList<var> *VarList,
}
}
MPI_Group_free(&world_group);
MPI_Group_free(&local_group);
delete[] owner_rank;
}
MPI_Group_free(&world_group);
MPI_Group_free(&local_group);
delete[] owner_rank;
delete[] owner_block;
}
void Patch::checkBlock()
{
int myrank;