Cache repeated interpolation plans

This commit is contained in:
2026-04-09 15:21:01 +08:00
parent 06fa643365
commit 42e851d19a
4 changed files with 428 additions and 370 deletions

View File

@@ -953,14 +953,36 @@ int bssn_cuda_interp_points_batch(const int *ex,
// Scratch storage that bssn_cuda_interp_points_batch keeps alive between calls
// (the single instance is declared `static thread_local`, so one per thread).
struct InterpBatchCache
{
    // One remembered interpolation plan: the inputs that identify it plus the
    // weight / index / reflect tables produced for those inputs.
    struct StencilCacheEntry
    {
        // Lookup key, part 1: grid-coordinate and sample-point arrays.
        // The lookup compares these by pointer value only, not by contents.
        // NOTE(review): if a caller rewrites the data behind the same
        // addresses, a previously built plan would still match -- confirm
        // that callers never do this.
        const double *X = nullptr;
        const double *Y = nullptr;
        const double *Z = nullptr;
        const double *px = nullptr;
        const double *py = nullptr;
        const double *pz = nullptr;

        // Lookup key, part 2: grid extents, number of points, interpolation
        // order and symmetry mode, all compared for equality.
        int nx = 0;
        int ny = 0;
        int nz = 0;
        int num_points = 0;
        int ordn = 0;
        int symmetry = 0;

        // Plan payload, filled through copy_to_device when the entry is built.
        // NOTE(review): the entry is appended to stencil_entries before these
        // are filled, and the build path can `return 1` in between; verify
        // that a later call with the same key cannot pick up a half-built
        // entry.
        CachedBuffer weights;
        CachedIntBuffer indices;
        CachedIntBuffer reflect;
    };

    // Buffers shared by every call regardless of which plan is used.
    CachedBuffer out;
    CachedBuffer soa;
    CachedBuffer field_ptrs;

    // Stencil tables stored directly on the cache rather than on an entry.
    // NOTE(review): the kernel-argument setup reads the per-entry tables
    // instead; check whether these three are still needed.
    CachedBuffer weights;
    CachedIntBuffer indices;
    CachedIntBuffer reflect;

    CachedIntBuffer error_flag;

    // Usage is not visible in this part of the file -- presumably staging
    // copies of the input fields; TODO confirm.
    std::vector<CachedBuffer> host_field_copies;

    // Every plan built so far on this thread; searched linearly per call.
    std::vector<StencilCacheEntry> stencil_entries;
};
static thread_local InterpBatchCache cache;
@@ -978,45 +1000,82 @@ int bssn_cuda_interp_points_batch(const int *ex,
const size_t indices_bytes = point_stencil_ints * sizeof(int);
bool ok = true;
std::vector<double> host_weights(point_stencil_doubles);
std::vector<int> host_indices(point_stencil_ints);
std::vector<int> host_reflect(point_stencil_ints);
const double dx = X[1] - X[0];
const double dy = Y[1] - Y[0];
const double dz = Z[1] - Z[0];
const int allow_reflect_x = (symmetry == 2 && std::fabs(X[0]) < dx);
const int allow_reflect_y = (symmetry == 2 && std::fabs(Y[0]) < dy);
const int allow_reflect_z = (symmetry != 0 && std::fabs(Z[0]) < dz);
for (int p = 0; p < num_points; ++p)
InterpBatchCache::StencilCacheEntry *stencil_cache = nullptr;
for (size_t i = 0; i < cache.stencil_entries.size(); ++i)
{
int start_x = 0, start_y = 0, start_z = 0;
double cx = 0.0, cy = 0.0, cz = 0.0;
const bool ok_x = compute_interp_window_host(px[p], X, nx, ordn, allow_reflect_x, &start_x, &cx);
const bool ok_y = compute_interp_window_host(py[p], Y, ny, ordn, allow_reflect_y, &start_y, &cy);
const bool ok_z = compute_interp_window_host(pz[p], Z, nz, ordn, allow_reflect_z, &start_z, &cz);
if (!ok_x || !ok_y || !ok_z)
return 1;
lagrange_weights_ord6_host(cx, host_weights.data() + p * 18);
lagrange_weights_ord6_host(cy, host_weights.data() + p * 18 + 6);
lagrange_weights_ord6_host(cz, host_weights.data() + p * 18 + 12);
for (int i = 0; i < 6; ++i)
InterpBatchCache::StencilCacheEntry &entry = cache.stencil_entries[i];
if (entry.X == X && entry.Y == Y && entry.Z == Z &&
entry.px == px && entry.py == py && entry.pz == pz &&
entry.nx == nx && entry.ny == ny && entry.nz == nz &&
entry.num_points == num_points && entry.ordn == ordn &&
entry.symmetry == symmetry)
{
if (!map_interp_index_host(start_x + i, nx, &host_indices[p * 18 + i], &host_reflect[p * 18 + i]))
return 1;
if (!map_interp_index_host(start_y + i, ny, &host_indices[p * 18 + 6 + i], &host_reflect[p * 18 + 6 + i]))
return 1;
if (!map_interp_index_host(start_z + i, nz, &host_indices[p * 18 + 12 + i], &host_reflect[p * 18 + 12 + i]))
return 1;
stencil_cache = &entry;
break;
}
}
if (!stencil_cache)
{
cache.stencil_entries.push_back(InterpBatchCache::StencilCacheEntry());
stencil_cache = &cache.stencil_entries.back();
stencil_cache->X = X;
stencil_cache->Y = Y;
stencil_cache->Z = Z;
stencil_cache->px = px;
stencil_cache->py = py;
stencil_cache->pz = pz;
stencil_cache->nx = nx;
stencil_cache->ny = ny;
stencil_cache->nz = nz;
stencil_cache->num_points = num_points;
stencil_cache->ordn = ordn;
stencil_cache->symmetry = symmetry;
std::vector<double> host_weights(point_stencil_doubles);
std::vector<int> host_indices(point_stencil_ints);
std::vector<int> host_reflect(point_stencil_ints);
const double dx = X[1] - X[0];
const double dy = Y[1] - Y[0];
const double dz = Z[1] - Z[0];
const int allow_reflect_x = (symmetry == 2 && std::fabs(X[0]) < dx);
const int allow_reflect_y = (symmetry == 2 && std::fabs(Y[0]) < dy);
const int allow_reflect_z = (symmetry != 0 && std::fabs(Z[0]) < dz);
for (int p = 0; p < num_points; ++p)
{
int start_x = 0, start_y = 0, start_z = 0;
double cx = 0.0, cy = 0.0, cz = 0.0;
const bool ok_x = compute_interp_window_host(px[p], X, nx, ordn, allow_reflect_x, &start_x, &cx);
const bool ok_y = compute_interp_window_host(py[p], Y, ny, ordn, allow_reflect_y, &start_y, &cy);
const bool ok_z = compute_interp_window_host(pz[p], Z, nz, ordn, allow_reflect_z, &start_z, &cz);
if (!ok_x || !ok_y || !ok_z)
return 1;
lagrange_weights_ord6_host(cx, host_weights.data() + p * 18);
lagrange_weights_ord6_host(cy, host_weights.data() + p * 18 + 6);
lagrange_weights_ord6_host(cz, host_weights.data() + p * 18 + 12);
for (int i = 0; i < 6; ++i)
{
if (!map_interp_index_host(start_x + i, nx, &host_indices[p * 18 + i], &host_reflect[p * 18 + i]))
return 1;
if (!map_interp_index_host(start_y + i, ny, &host_indices[p * 18 + 6 + i], &host_reflect[p * 18 + 6 + i]))
return 1;
if (!map_interp_index_host(start_z + i, nz, &host_indices[p * 18 + 12 + i], &host_reflect[p * 18 + 12 + i]))
return 1;
}
}
ok = ok &&
copy_to_device(stencil_cache->weights, host_weights.data(), weights_bytes) &&
copy_to_device(stencil_cache->indices, host_indices.data(), indices_bytes) &&
copy_to_device(stencil_cache->reflect, host_reflect.data(), indices_bytes);
if (!ok)
return 1;
}
ok = ok &&
copy_to_device(cache.soa, soa_flat, soa_bytes) &&
copy_to_device(cache.weights, host_weights.data(), weights_bytes) &&
copy_to_device(cache.indices, host_indices.data(), indices_bytes) &&
copy_to_device(cache.reflect, host_reflect.data(), indices_bytes) &&
ensure_capacity(cache.out, out_bytes) &&
ensure_capacity(cache.field_ptrs, ptr_bytes) &&
ensure_capacity(cache.error_flag, sizeof(int));
@@ -1063,9 +1122,9 @@ int bssn_cuda_interp_points_batch(const int *ex,
int ny_local = ny;
const double *dsoa = cache.soa.ptr;
const double *const *dfields = reinterpret_cast<const double *const *>(cache.field_ptrs.ptr);
const double *dweights = cache.weights.ptr;
const int *dindices = cache.indices.ptr;
const int *dreflect = cache.reflect.ptr;
const double *dweights = stencil_cache->weights.ptr;
const int *dindices = stencil_cache->indices.ptr;
const int *dreflect = stencil_cache->reflect.ptr;
double *dout = cache.out.ptr;
int *derror = cache.error_flag.ptr;