Cache repeated interpolation plans
This commit is contained in:
@@ -953,14 +953,36 @@ int bssn_cuda_interp_points_batch(const int *ex,
|
||||
|
||||
// Reusable buffers for bssn_cuda_interp_points_batch. A single instance is
// declared `static thread_local`, so the buffers persist across calls made
// on the same thread.
struct InterpBatchCache
|
||||
{
|
||||
// One cached interpolation plan (per-point weights, grid indices and
// reflection flags) together with the inputs it was built from.
//
// The lookup code treats X..symmetry as the key and compares the pointer
// fields by address only; the pointed-to data is not compared.
// NOTE(review): if a caller rewrites the contents of the same arrays between
// calls, a stale plan would be reused -- confirm callers never do that.
struct StencilCacheEntry
|
||||
{
|
||||
// Key: addresses of the grid coordinate arrays passed by the caller.
const double *X;
|
||||
const double *Y;
|
||||
const double *Z;
|
||||
// Key: addresses of the interpolation point coordinate arrays.
const double *px;
|
||||
const double *py;
|
||||
const double *pz;
|
||||
// Key: grid extents handed to the window/index helpers.
int nx;
|
||||
int ny;
|
||||
int nz;
|
||||
// Key: number of interpolation points in the batch.
int num_points;
|
||||
// Key: interpolation order argument forwarded to the window helper.
int ordn;
|
||||
// Key: symmetry mode; it decides whether reflection is allowed per axis.
int symmetry;
|
||||
// Plan data filled through copy_to_device on a cache miss. The host staging
// arrays use 18 values per point: 6 for x, then 6 for y, then 6 for z.
CachedBuffer weights;
|
||||
CachedIntBuffer indices;
|
||||
CachedIntBuffer reflect;
|
||||
|
||||
// Starts with a null/zero key; the caller fills the key in right after
// appending a new entry.
StencilCacheEntry()
|
||||
: X(nullptr), Y(nullptr), Z(nullptr),
|
||||
px(nullptr), py(nullptr), pz(nullptr),
|
||||
nx(0), ny(0), nz(0), num_points(0), ordn(0), symmetry(0) {}
|
||||
};
|
||||
|
||||
// Output buffer; sized with ensure_capacity(out_bytes) before use.
CachedBuffer out;
|
||||
// Destination of the copy_to_device(soa_flat, soa_bytes) upload.
CachedBuffer soa;
|
||||
// Storage for the table of field pointers; read back as
// `const double *const *` when the kernel arguments are set up.
CachedBuffer field_ptrs;
|
||||
// Batch-level plan buffers.
// NOTE(review): this excerpt shows plan data being read from both these
// members and the matched StencilCacheEntry -- confirm which set is current.
CachedBuffer weights;
|
||||
CachedIntBuffer indices;
|
||||
CachedIntBuffer reflect;
|
||||
// Single-int buffer; sized with ensure_capacity(sizeof(int)).
CachedIntBuffer error_flag;
|
||||
std::vector<CachedBuffer> host_field_copies;
|
||||
// Plans built so far. Searched linearly on every call and appended to on a
// miss; no eviction is visible in this excerpt.
std::vector<StencilCacheEntry> stencil_entries;
|
||||
};
|
||||
static thread_local InterpBatchCache cache;
|
||||
|
||||
@@ -978,45 +1000,82 @@ int bssn_cuda_interp_points_batch(const int *ex,
|
||||
const size_t indices_bytes = point_stencil_ints * sizeof(int);
|
||||
|
||||
bool ok = true;
|
||||
std::vector<double> host_weights(point_stencil_doubles);
|
||||
std::vector<int> host_indices(point_stencil_ints);
|
||||
std::vector<int> host_reflect(point_stencil_ints);
|
||||
const double dx = X[1] - X[0];
|
||||
const double dy = Y[1] - Y[0];
|
||||
const double dz = Z[1] - Z[0];
|
||||
const int allow_reflect_x = (symmetry == 2 && std::fabs(X[0]) < dx);
|
||||
const int allow_reflect_y = (symmetry == 2 && std::fabs(Y[0]) < dy);
|
||||
const int allow_reflect_z = (symmetry != 0 && std::fabs(Z[0]) < dz);
|
||||
for (int p = 0; p < num_points; ++p)
|
||||
InterpBatchCache::StencilCacheEntry *stencil_cache = nullptr;
|
||||
for (size_t i = 0; i < cache.stencil_entries.size(); ++i)
|
||||
{
|
||||
int start_x = 0, start_y = 0, start_z = 0;
|
||||
double cx = 0.0, cy = 0.0, cz = 0.0;
|
||||
const bool ok_x = compute_interp_window_host(px[p], X, nx, ordn, allow_reflect_x, &start_x, &cx);
|
||||
const bool ok_y = compute_interp_window_host(py[p], Y, ny, ordn, allow_reflect_y, &start_y, &cy);
|
||||
const bool ok_z = compute_interp_window_host(pz[p], Z, nz, ordn, allow_reflect_z, &start_z, &cz);
|
||||
if (!ok_x || !ok_y || !ok_z)
|
||||
return 1;
|
||||
|
||||
lagrange_weights_ord6_host(cx, host_weights.data() + p * 18);
|
||||
lagrange_weights_ord6_host(cy, host_weights.data() + p * 18 + 6);
|
||||
lagrange_weights_ord6_host(cz, host_weights.data() + p * 18 + 12);
|
||||
|
||||
for (int i = 0; i < 6; ++i)
|
||||
InterpBatchCache::StencilCacheEntry &entry = cache.stencil_entries[i];
|
||||
if (entry.X == X && entry.Y == Y && entry.Z == Z &&
|
||||
entry.px == px && entry.py == py && entry.pz == pz &&
|
||||
entry.nx == nx && entry.ny == ny && entry.nz == nz &&
|
||||
entry.num_points == num_points && entry.ordn == ordn &&
|
||||
entry.symmetry == symmetry)
|
||||
{
|
||||
if (!map_interp_index_host(start_x + i, nx, &host_indices[p * 18 + i], &host_reflect[p * 18 + i]))
|
||||
return 1;
|
||||
if (!map_interp_index_host(start_y + i, ny, &host_indices[p * 18 + 6 + i], &host_reflect[p * 18 + 6 + i]))
|
||||
return 1;
|
||||
if (!map_interp_index_host(start_z + i, nz, &host_indices[p * 18 + 12 + i], &host_reflect[p * 18 + 12 + i]))
|
||||
return 1;
|
||||
stencil_cache = &entry;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!stencil_cache)
|
||||
{
|
||||
cache.stencil_entries.push_back(InterpBatchCache::StencilCacheEntry());
|
||||
stencil_cache = &cache.stencil_entries.back();
|
||||
stencil_cache->X = X;
|
||||
stencil_cache->Y = Y;
|
||||
stencil_cache->Z = Z;
|
||||
stencil_cache->px = px;
|
||||
stencil_cache->py = py;
|
||||
stencil_cache->pz = pz;
|
||||
stencil_cache->nx = nx;
|
||||
stencil_cache->ny = ny;
|
||||
stencil_cache->nz = nz;
|
||||
stencil_cache->num_points = num_points;
|
||||
stencil_cache->ordn = ordn;
|
||||
stencil_cache->symmetry = symmetry;
|
||||
|
||||
std::vector<double> host_weights(point_stencil_doubles);
|
||||
std::vector<int> host_indices(point_stencil_ints);
|
||||
std::vector<int> host_reflect(point_stencil_ints);
|
||||
const double dx = X[1] - X[0];
|
||||
const double dy = Y[1] - Y[0];
|
||||
const double dz = Z[1] - Z[0];
|
||||
const int allow_reflect_x = (symmetry == 2 && std::fabs(X[0]) < dx);
|
||||
const int allow_reflect_y = (symmetry == 2 && std::fabs(Y[0]) < dy);
|
||||
const int allow_reflect_z = (symmetry != 0 && std::fabs(Z[0]) < dz);
|
||||
for (int p = 0; p < num_points; ++p)
|
||||
{
|
||||
int start_x = 0, start_y = 0, start_z = 0;
|
||||
double cx = 0.0, cy = 0.0, cz = 0.0;
|
||||
const bool ok_x = compute_interp_window_host(px[p], X, nx, ordn, allow_reflect_x, &start_x, &cx);
|
||||
const bool ok_y = compute_interp_window_host(py[p], Y, ny, ordn, allow_reflect_y, &start_y, &cy);
|
||||
const bool ok_z = compute_interp_window_host(pz[p], Z, nz, ordn, allow_reflect_z, &start_z, &cz);
|
||||
if (!ok_x || !ok_y || !ok_z)
|
||||
return 1;
|
||||
|
||||
lagrange_weights_ord6_host(cx, host_weights.data() + p * 18);
|
||||
lagrange_weights_ord6_host(cy, host_weights.data() + p * 18 + 6);
|
||||
lagrange_weights_ord6_host(cz, host_weights.data() + p * 18 + 12);
|
||||
|
||||
for (int i = 0; i < 6; ++i)
|
||||
{
|
||||
if (!map_interp_index_host(start_x + i, nx, &host_indices[p * 18 + i], &host_reflect[p * 18 + i]))
|
||||
return 1;
|
||||
if (!map_interp_index_host(start_y + i, ny, &host_indices[p * 18 + 6 + i], &host_reflect[p * 18 + 6 + i]))
|
||||
return 1;
|
||||
if (!map_interp_index_host(start_z + i, nz, &host_indices[p * 18 + 12 + i], &host_reflect[p * 18 + 12 + i]))
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
ok = ok &&
|
||||
copy_to_device(stencil_cache->weights, host_weights.data(), weights_bytes) &&
|
||||
copy_to_device(stencil_cache->indices, host_indices.data(), indices_bytes) &&
|
||||
copy_to_device(stencil_cache->reflect, host_reflect.data(), indices_bytes);
|
||||
if (!ok)
|
||||
return 1;
|
||||
}
|
||||
|
||||
ok = ok &&
|
||||
copy_to_device(cache.soa, soa_flat, soa_bytes) &&
|
||||
copy_to_device(cache.weights, host_weights.data(), weights_bytes) &&
|
||||
copy_to_device(cache.indices, host_indices.data(), indices_bytes) &&
|
||||
copy_to_device(cache.reflect, host_reflect.data(), indices_bytes) &&
|
||||
ensure_capacity(cache.out, out_bytes) &&
|
||||
ensure_capacity(cache.field_ptrs, ptr_bytes) &&
|
||||
ensure_capacity(cache.error_flag, sizeof(int));
|
||||
@@ -1063,9 +1122,9 @@ int bssn_cuda_interp_points_batch(const int *ex,
|
||||
int ny_local = ny;
|
||||
const double *dsoa = cache.soa.ptr;
|
||||
const double *const *dfields = reinterpret_cast<const double *const *>(cache.field_ptrs.ptr);
|
||||
const double *dweights = cache.weights.ptr;
|
||||
const int *dindices = cache.indices.ptr;
|
||||
const int *dreflect = cache.reflect.ptr;
|
||||
const double *dweights = stencil_cache->weights.ptr;
|
||||
const int *dindices = stencil_cache->indices.ptr;
|
||||
const int *dreflect = stencil_cache->reflect.ptr;
|
||||
double *dout = cache.out.ptr;
|
||||
int *derror = cache.error_flag.ptr;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user