Refine batched CUDA interpolation kernel
This commit is contained in:
@@ -98,6 +98,20 @@ inline bool ensure_capacity(CachedIntBuffer &buffer, size_t bytes)
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool copy_to_device(CachedIntBuffer &dst, const int *src, size_t bytes)
|
||||
{
|
||||
if (!ensure_capacity(dst, bytes))
|
||||
return false;
|
||||
|
||||
cudaError_t err = cudaMemcpy(dst.ptr, src, bytes, cudaMemcpyHostToDevice);
|
||||
if (err != cudaSuccess)
|
||||
{
|
||||
report_cuda_error("cudaMemcpy(H2D int)", err);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool copy_to_device(CachedBuffer &dst, const double *src, size_t bytes)
|
||||
{
|
||||
if (!ensure_capacity(dst, bytes))
|
||||
@@ -135,6 +149,72 @@ inline bool copy_to_device_preferring_device(CachedBuffer &dst, const double *sr
|
||||
return true;
|
||||
}
|
||||
|
||||
inline int interp_idint_like_host(double x)
|
||||
{
|
||||
return static_cast<int>(x);
|
||||
}
|
||||
|
||||
inline void lagrange_weights_ord6_host(double x, double *w)
|
||||
{
|
||||
static const double denom[6] = {-120.0, 24.0, -12.0, 12.0, -24.0, 120.0};
|
||||
for (int i = 0; i < 6; ++i)
|
||||
{
|
||||
double num = 1.0;
|
||||
for (int j = 0; j < 6; ++j)
|
||||
{
|
||||
if (j != i)
|
||||
num *= (x - static_cast<double>(j));
|
||||
}
|
||||
w[i] = num / denom[i];
|
||||
}
|
||||
}
|
||||
|
||||
inline bool map_interp_index_host(int logical_idx, int n, int *mapped_idx, int *reflected)
|
||||
{
|
||||
int idx = logical_idx;
|
||||
*reflected = 0;
|
||||
if (idx < 0)
|
||||
{
|
||||
idx = -idx;
|
||||
*reflected = 1;
|
||||
}
|
||||
if (idx < 0 || idx >= n)
|
||||
return false;
|
||||
*mapped_idx = idx;
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool compute_interp_window_host(double x, const double *coord, int n,
|
||||
int ordn, int allow_reflect,
|
||||
int *start_idx, double *cx)
|
||||
{
|
||||
const double dx = coord[1] - coord[0];
|
||||
const int center = interp_idint_like_host((x - coord[0]) / dx + 0.4);
|
||||
const int cmin = allow_reflect ? (-ordn / 2 + 1) : 0;
|
||||
const int cmax = n - 1;
|
||||
|
||||
int begin = center - ordn / 2 + 1;
|
||||
int end = begin + ordn - 1;
|
||||
if (begin < cmin)
|
||||
{
|
||||
begin = cmin;
|
||||
end = begin + ordn - 1;
|
||||
}
|
||||
if (end > cmax)
|
||||
{
|
||||
end = cmax;
|
||||
begin = end + 1 - ordn;
|
||||
}
|
||||
|
||||
if (begin >= 0)
|
||||
*cx = (x - coord[begin]) / dx;
|
||||
else
|
||||
*cx = (x + coord[-begin]) / dx;
|
||||
|
||||
*start_idx = begin;
|
||||
return (begin >= cmin && end <= cmax);
|
||||
}
|
||||
|
||||
__global__ void enforce_ga_kernel(int n,
|
||||
double *dxx, double *gxy, double *gxz,
|
||||
double *dyy, double *gyz, double *dzz,
|
||||
@@ -295,80 +375,13 @@ __global__ void prolong3_cell_kernel(const double *funcc, double *funf,
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline int interp_idint_like(double x)
|
||||
{
|
||||
return static_cast<int>(x);
|
||||
}
|
||||
|
||||
__device__ inline void lagrange_weights_ord6(double x, double *w)
|
||||
{
|
||||
const double denom[6] = {-120.0, 24.0, -12.0, 12.0, -24.0, 120.0};
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 6; ++i)
|
||||
{
|
||||
double num = 1.0;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 6; ++j)
|
||||
{
|
||||
if (j != i)
|
||||
num *= (x - static_cast<double>(j));
|
||||
}
|
||||
w[i] = num / denom[i];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline bool map_interp_index(int logical_idx, int n, double soa, int *mapped_idx, double *sign)
|
||||
{
|
||||
int idx = logical_idx;
|
||||
if (idx < 0)
|
||||
{
|
||||
idx = -idx;
|
||||
*sign *= soa;
|
||||
}
|
||||
if (idx < 0 || idx >= n)
|
||||
return false;
|
||||
*mapped_idx = idx;
|
||||
return true;
|
||||
}
|
||||
|
||||
__device__ inline bool compute_interp_window(double x, const double *coord, int n,
|
||||
int ordn, int allow_reflect,
|
||||
int *start_idx, double *cx)
|
||||
{
|
||||
const double dx = coord[1] - coord[0];
|
||||
const int center = interp_idint_like((x - coord[0]) / dx + 0.4);
|
||||
const int cmin = allow_reflect ? (-ordn / 2 + 1) : 0;
|
||||
const int cmax = n - 1;
|
||||
|
||||
int begin = center - ordn / 2 + 1;
|
||||
int end = begin + ordn - 1;
|
||||
if (begin < cmin)
|
||||
{
|
||||
begin = cmin;
|
||||
end = begin + ordn - 1;
|
||||
}
|
||||
if (end > cmax)
|
||||
{
|
||||
end = cmax;
|
||||
begin = end + 1 - ordn;
|
||||
}
|
||||
|
||||
if (begin >= 0)
|
||||
*cx = (x - coord[begin]) / dx;
|
||||
else
|
||||
*cx = (x + coord[-begin]) / dx;
|
||||
|
||||
*start_idx = begin;
|
||||
return (begin >= cmin && end <= cmax);
|
||||
}
|
||||
|
||||
__global__ void interp_points_ord6_kernel(int num_points, int num_var,
|
||||
int nx, int ny, int nz,
|
||||
const double *X, const double *Y, const double *Z,
|
||||
int nx, int ny,
|
||||
const double *const *fields,
|
||||
const double *soa_flat,
|
||||
const double *px, const double *py, const double *pz,
|
||||
int symmetry,
|
||||
const int *stencil_idx,
|
||||
const int *stencil_reflect,
|
||||
const double *stencil_weights,
|
||||
double *out,
|
||||
int *error_flag)
|
||||
{
|
||||
@@ -379,62 +392,38 @@ __global__ void interp_points_ord6_kernel(int num_points, int num_var,
|
||||
const int var_id = idx - point_id * num_var;
|
||||
const double *field = fields[var_id];
|
||||
const double *soa = soa_flat + 3 * var_id;
|
||||
|
||||
const double dx = X[1] - X[0];
|
||||
const double dy = Y[1] - Y[0];
|
||||
const double dz = Z[1] - Z[0];
|
||||
const int allow_reflect_x = (symmetry == 2 && fabs(X[0]) < dx);
|
||||
const int allow_reflect_y = (symmetry == 2 && fabs(Y[0]) < dy);
|
||||
const int allow_reflect_z = (symmetry != 0 && fabs(Z[0]) < dz);
|
||||
|
||||
int start_x = 0, start_y = 0, start_z = 0;
|
||||
double cx = 0.0, cy = 0.0, cz = 0.0;
|
||||
const bool ok_x = compute_interp_window(px[point_id], X, nx, 6, allow_reflect_x, &start_x, &cx);
|
||||
const bool ok_y = compute_interp_window(py[point_id], Y, ny, 6, allow_reflect_y, &start_y, &cy);
|
||||
const bool ok_z = compute_interp_window(pz[point_id], Z, nz, 6, allow_reflect_z, &start_z, &cz);
|
||||
if (!ok_x || !ok_y || !ok_z)
|
||||
{
|
||||
atomicExch(error_flag, 1);
|
||||
out[idx] = 0.0;
|
||||
continue;
|
||||
}
|
||||
|
||||
double wx[6], wy[6], wz[6];
|
||||
lagrange_weights_ord6(cx, wx);
|
||||
lagrange_weights_ord6(cy, wy);
|
||||
lagrange_weights_ord6(cz, wz);
|
||||
const int *idxp = stencil_idx + point_id * 18;
|
||||
const int *refp = stencil_reflect + point_id * 18;
|
||||
const double *wp = stencil_weights + point_id * 18;
|
||||
|
||||
double value = 0.0;
|
||||
#pragma unroll
|
||||
for (int kz = 0; kz < 6; ++kz)
|
||||
{
|
||||
const int z_idx = idxp[12 + kz];
|
||||
const double z_sign = refp[12 + kz] ? soa[2] : 1.0;
|
||||
double yz_sum = 0.0;
|
||||
#pragma unroll
|
||||
for (int jy = 0; jy < 6; ++jy)
|
||||
{
|
||||
const int y_idx = idxp[6 + jy];
|
||||
const double y_sign = refp[6 + jy] ? soa[1] : 1.0;
|
||||
double x_sum = 0.0;
|
||||
#pragma unroll
|
||||
for (int ix = 0; ix < 6; ++ix)
|
||||
{
|
||||
double sign = 1.0;
|
||||
int sx = 0, sy = 0, sz = 0;
|
||||
const bool ok_map =
|
||||
map_interp_index(start_x + ix, nx, soa[0], &sx, &sign) &&
|
||||
map_interp_index(start_y + jy, ny, soa[1], &sy, &sign) &&
|
||||
map_interp_index(start_z + kz, nz, soa[2], &sz, &sign);
|
||||
if (!ok_map)
|
||||
{
|
||||
atomicExch(error_flag, 1);
|
||||
continue;
|
||||
}
|
||||
x_sum += wx[ix] * sign * field[index3(sx, sy, sz, nx, ny)];
|
||||
const int x_idx = idxp[ix];
|
||||
const double x_sign = refp[ix] ? soa[0] : 1.0;
|
||||
const double sign = x_sign * y_sign * z_sign;
|
||||
x_sum += wp[ix] * sign * field[index3(x_idx, y_idx, z_idx, nx, ny)];
|
||||
}
|
||||
yz_sum += wy[jy] * x_sum;
|
||||
yz_sum += wp[6 + jy] * x_sum;
|
||||
}
|
||||
value += wz[kz] * yz_sum;
|
||||
value += wp[12 + kz] * yz_sum;
|
||||
}
|
||||
out[idx] = value;
|
||||
}
|
||||
(void)error_flag;
|
||||
}
|
||||
|
||||
__global__ void rk4_kernel(int n, double dT,
|
||||
@@ -964,19 +953,14 @@ int bssn_cuda_interp_points_batch(const int *ex,
|
||||
|
||||
struct InterpBatchCache
|
||||
{
|
||||
CachedBuffer X, Y, Z;
|
||||
CachedBuffer px, py, pz;
|
||||
CachedBuffer out;
|
||||
CachedBuffer soa;
|
||||
CachedBuffer field_ptrs;
|
||||
CachedBuffer weights;
|
||||
CachedIntBuffer indices;
|
||||
CachedIntBuffer reflect;
|
||||
CachedIntBuffer error_flag;
|
||||
std::vector<CachedBuffer> host_field_copies;
|
||||
const double *host_X = nullptr;
|
||||
const double *host_Y = nullptr;
|
||||
const double *host_Z = nullptr;
|
||||
int nx = 0;
|
||||
int ny = 0;
|
||||
int nz = 0;
|
||||
};
|
||||
static thread_local InterpBatchCache cache;
|
||||
|
||||
@@ -984,38 +968,55 @@ int bssn_cuda_interp_points_batch(const int *ex,
|
||||
const int ny = ex[1];
|
||||
const int nz = ex[2];
|
||||
const int field_points = count_points(ex);
|
||||
const size_t coord_bytes_x = static_cast<size_t>(nx) * sizeof(double);
|
||||
const size_t coord_bytes_y = static_cast<size_t>(ny) * sizeof(double);
|
||||
const size_t coord_bytes_z = static_cast<size_t>(nz) * sizeof(double);
|
||||
const size_t field_bytes = static_cast<size_t>(field_points) * sizeof(double);
|
||||
const size_t point_bytes = static_cast<size_t>(num_points) * sizeof(double);
|
||||
const size_t out_bytes = static_cast<size_t>(num_points) * static_cast<size_t>(num_var) * sizeof(double);
|
||||
const size_t soa_bytes = static_cast<size_t>(3 * num_var) * sizeof(double);
|
||||
const size_t ptr_bytes = static_cast<size_t>(num_var) * sizeof(double *);
|
||||
const size_t point_stencil_doubles = static_cast<size_t>(num_points) * 18;
|
||||
const size_t point_stencil_ints = static_cast<size_t>(num_points) * 18;
|
||||
const size_t weights_bytes = point_stencil_doubles * sizeof(double);
|
||||
const size_t indices_bytes = point_stencil_ints * sizeof(int);
|
||||
|
||||
bool ok = true;
|
||||
if (cache.host_X != X || cache.host_Y != Y || cache.host_Z != Z ||
|
||||
cache.nx != nx || cache.ny != ny || cache.nz != nz)
|
||||
std::vector<double> host_weights(point_stencil_doubles);
|
||||
std::vector<int> host_indices(point_stencil_ints);
|
||||
std::vector<int> host_reflect(point_stencil_ints);
|
||||
const double dx = X[1] - X[0];
|
||||
const double dy = Y[1] - Y[0];
|
||||
const double dz = Z[1] - Z[0];
|
||||
const int allow_reflect_x = (symmetry == 2 && std::fabs(X[0]) < dx);
|
||||
const int allow_reflect_y = (symmetry == 2 && std::fabs(Y[0]) < dy);
|
||||
const int allow_reflect_z = (symmetry != 0 && std::fabs(Z[0]) < dz);
|
||||
for (int p = 0; p < num_points; ++p)
|
||||
{
|
||||
ok = copy_to_device(cache.X, X, coord_bytes_x) &&
|
||||
copy_to_device(cache.Y, Y, coord_bytes_y) &&
|
||||
copy_to_device(cache.Z, Z, coord_bytes_z);
|
||||
if (ok)
|
||||
int start_x = 0, start_y = 0, start_z = 0;
|
||||
double cx = 0.0, cy = 0.0, cz = 0.0;
|
||||
const bool ok_x = compute_interp_window_host(px[p], X, nx, ordn, allow_reflect_x, &start_x, &cx);
|
||||
const bool ok_y = compute_interp_window_host(py[p], Y, ny, ordn, allow_reflect_y, &start_y, &cy);
|
||||
const bool ok_z = compute_interp_window_host(pz[p], Z, nz, ordn, allow_reflect_z, &start_z, &cz);
|
||||
if (!ok_x || !ok_y || !ok_z)
|
||||
return 1;
|
||||
|
||||
lagrange_weights_ord6_host(cx, host_weights.data() + p * 18);
|
||||
lagrange_weights_ord6_host(cy, host_weights.data() + p * 18 + 6);
|
||||
lagrange_weights_ord6_host(cz, host_weights.data() + p * 18 + 12);
|
||||
|
||||
for (int i = 0; i < 6; ++i)
|
||||
{
|
||||
cache.host_X = X;
|
||||
cache.host_Y = Y;
|
||||
cache.host_Z = Z;
|
||||
cache.nx = nx;
|
||||
cache.ny = ny;
|
||||
cache.nz = nz;
|
||||
if (!map_interp_index_host(start_x + i, nx, &host_indices[p * 18 + i], &host_reflect[p * 18 + i]))
|
||||
return 1;
|
||||
if (!map_interp_index_host(start_y + i, ny, &host_indices[p * 18 + 6 + i], &host_reflect[p * 18 + 6 + i]))
|
||||
return 1;
|
||||
if (!map_interp_index_host(start_z + i, nz, &host_indices[p * 18 + 12 + i], &host_reflect[p * 18 + 12 + i]))
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
ok = ok &&
|
||||
copy_to_device(cache.px, px, point_bytes) &&
|
||||
copy_to_device(cache.py, py, point_bytes) &&
|
||||
copy_to_device(cache.pz, pz, point_bytes) &&
|
||||
copy_to_device(cache.soa, soa_flat, soa_bytes) &&
|
||||
copy_to_device(cache.weights, host_weights.data(), weights_bytes) &&
|
||||
copy_to_device(cache.indices, host_indices.data(), indices_bytes) &&
|
||||
copy_to_device(cache.reflect, host_reflect.data(), indices_bytes) &&
|
||||
ensure_capacity(cache.out, out_bytes) &&
|
||||
ensure_capacity(cache.field_ptrs, ptr_bytes) &&
|
||||
ensure_capacity(cache.error_flag, sizeof(int));
|
||||
@@ -1060,26 +1061,22 @@ int bssn_cuda_interp_points_batch(const int *ex,
|
||||
|
||||
int nx_local = nx;
|
||||
int ny_local = ny;
|
||||
int nz_local = nz;
|
||||
const double *dX = cache.X.ptr;
|
||||
const double *dY = cache.Y.ptr;
|
||||
const double *dZ = cache.Z.ptr;
|
||||
const double *dpx = cache.px.ptr;
|
||||
const double *dpy = cache.py.ptr;
|
||||
const double *dpz = cache.pz.ptr;
|
||||
const double *dsoa = cache.soa.ptr;
|
||||
const double *const *dfields = reinterpret_cast<const double *const *>(cache.field_ptrs.ptr);
|
||||
const double *dweights = cache.weights.ptr;
|
||||
const int *dindices = cache.indices.ptr;
|
||||
const int *dreflect = cache.reflect.ptr;
|
||||
double *dout = cache.out.ptr;
|
||||
int *derror = cache.error_flag.ptr;
|
||||
|
||||
void *args[] = {
|
||||
&num_points, &num_var,
|
||||
&nx_local, &ny_local, &nz_local,
|
||||
&dX, &dY, &dZ,
|
||||
&nx_local, &ny_local,
|
||||
&dfields,
|
||||
&dsoa,
|
||||
&dpx, &dpy, &dpz,
|
||||
&symmetry,
|
||||
&dindices,
|
||||
&dreflect,
|
||||
&dweights,
|
||||
&dout,
|
||||
&derror};
|
||||
ok = launch_kernel(grid, block, (const void *)interp_points_ord6_kernel, args);
|
||||
|
||||
Reference in New Issue
Block a user